Skip to content

Commit de22abe

Browse files
committed
[WIP][SYCL] Add support for union types as kernel arguments
Signed-off-by: Soumi Manna <[email protected]>
1 parent 76ffef7 commit de22abe

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,13 @@ class KernelObjVisitor {
830830
else if (ElementTy->isStructureOrClassType())
831831
VisitRecord(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
832832
handlers...);
833+
else if (ElementTy->isUnionType())
834+
// TODO: This check is still necessary I think?! Array seems to handle
835+
// this differently (see above) for structs I think.
836+
//if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
837+
VisitUnion(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
838+
handlers...);
839+
//}
833840
else if (ElementTy->isArrayType())
834841
VisitArrayElements(ArrayField, ElementTy, handlers...);
835842
else if (ElementTy->isScalarType())
@@ -857,6 +864,41 @@ class KernelObjVisitor {
857864
void VisitRecord(const CXXRecordDecl *Owner, ParentTy &Parent,
858865
const CXXRecordDecl *Wrapper, Handlers &... handlers);
859866

867+
// Base case, only calls these when filtered.
868+
template <typename... FilteredHandlers, typename ParentTy>
869+
void VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
870+
const CXXRecordDecl *Wrapper,
871+
FilteredHandlers &... handlers) {
872+
(void)std::initializer_list<int>{
873+
(handlers.enterUnion(Owner, Parent), 0)...};
874+
VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...);
875+
(void)std::initializer_list<int>{
876+
(handlers.leaveUnion(Owner, Parent), 0)...};
877+
}
878+
879+
880+
template <typename... FilteredHandlers, typename ParentTy,
881+
typename CurHandler, typename... Handlers>
882+
std::enable_if_t<!CurHandler::VisitUnionBody>
883+
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
884+
const CXXRecordDecl *Wrapper,
885+
FilteredHandlers &... filtered_handlers,
886+
CurHandler &cur_handler, Handlers &... handlers) {
887+
VisitUnion<FilteredHandlers...>(
888+
Owner, Parent, Wrapper, filtered_handlers..., handlers...);
889+
}
890+
891+
template <typename... FilteredHandlers, typename ParentTy,
892+
typename CurHandler, typename... Handlers>
893+
std::enable_if_t<CurHandler::VisitUnionBody>
894+
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
895+
const CXXRecordDecl *Wrapper,
896+
FilteredHandlers &... filtered_handlers,
897+
CurHandler &cur_handler, Handlers &... handlers) {
898+
VisitUnion<FilteredHandlers..., CurHandler>(
899+
Owner, Parent, Wrapper, filtered_handlers..., cur_handler, handlers...);
900+
}
901+
860902
template <typename... Handlers>
861903
void VisitRecordHelper(const CXXRecordDecl *Owner,
862904
clang::CXXRecordDecl::base_class_const_range Range,
@@ -942,6 +984,11 @@ class KernelObjVisitor {
942984
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
943985
VisitRecord(Owner, Field, RD, handlers...);
944986
}
987+
} else if (FieldTy->isUnionType()) {
988+
if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
989+
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
990+
VisitUnion(Owner, Field, RD, handlers...);
991+
}
945992
} else if (FieldTy->isReferenceType())
946993
KF_FOR_EACH(handleReferenceType, Field, FieldTy);
947994
else if (FieldTy->isPointerType())
@@ -1005,6 +1052,7 @@ class SyclKernelFieldHandler {
10051052
}
10061053
virtual bool handleSyclHalfType(FieldDecl *, QualType) { return true; }
10071054
virtual bool handleStructType(FieldDecl *, QualType) { return true; }
1055+
virtual bool handleUnionType(FieldDecl *, QualType) { return true; }
10081056
virtual bool handleReferenceType(FieldDecl *, QualType) { return true; }
10091057
virtual bool handlePointerType(FieldDecl *, QualType) { return true; }
10101058
virtual bool handleArrayType(FieldDecl *, QualType) { return true; }
@@ -1024,6 +1072,8 @@ class SyclKernelFieldHandler {
10241072
virtual bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) {
10251073
return true;
10261074
}
1075+
virtual bool enterUnion(const CXXRecordDecl *, FieldDecl *) { return true; }
1076+
virtual bool leaveUnion(const CXXRecordDecl *, FieldDecl *) { return true; }
10271077

10281078
// The following are used for stepping through array elements.
10291079

@@ -1201,6 +1251,65 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
12011251
}
12021252
};
12031253

1254+
// A type to check the validity of passing union with accessor/sampler/stream
1255+
// member as a kernel argument types.
1256+
class SyclKernelUnionBodyChecker : public SyclKernelFieldHandler {
1257+
static constexpr const bool VisitUnionBody = true;
1258+
int UnionCount = 0;
1259+
bool IsInvalid = false;
1260+
DiagnosticsEngine &Diag;
1261+
1262+
public:
1263+
SyclKernelUnionBodyChecker(Sema &S)
1264+
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
1265+
bool isValid() { return !IsInvalid; }
1266+
1267+
bool enterUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
1268+
++UnionCount;
1269+
return true;
1270+
}
1271+
1272+
bool leaveUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
1273+
--UnionCount;
1274+
return true;
1275+
}
1276+
1277+
bool handlePointerType(FieldDecl *FD, QualType FieldTy) final {
1278+
if (UnionCount) {
1279+
IsInvalid = true;
1280+
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
1281+
<< FieldTy;
1282+
}
1283+
return isValid();
1284+
}
1285+
1286+
bool handleSyclAccessorType(FieldDecl *FD, QualType FieldTy) final {
1287+
if (UnionCount) {
1288+
IsInvalid = true;
1289+
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
1290+
<< FieldTy;
1291+
}
1292+
return isValid();
1293+
}
1294+
1295+
bool handleSyclSamplerType(FieldDecl *FD, QualType FieldTy) final {
1296+
if (UnionCount) {
1297+
IsInvalid = true;
1298+
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
1299+
<< FieldTy;
1300+
}
1301+
return isValid();
1302+
}
1303+
bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final {
1304+
if (UnionCount) {
1305+
IsInvalid = true;
1306+
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
1307+
<< FieldTy;
1308+
}
1309+
return isValid();
1310+
}
1311+
};
1312+
12041313
// A type to Create and own the FunctionDecl for the kernel.
12051314
class SyclKernelDeclCreator : public SyclKernelFieldHandler {
12061315
FunctionDecl *KernelDecl;
@@ -1416,6 +1525,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
14161525
return true;
14171526
}
14181527

1528+
bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
1529+
return handleScalarType(FD, FieldTy);
1530+
}
1531+
14191532
bool handleSyclHalfType(FieldDecl *FD, QualType FieldTy) final {
14201533
addParam(FD, FieldTy);
14211534
return true;
@@ -1751,6 +1864,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
17511864
return true;
17521865
}
17531866

1867+
bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
1868+
return handleScalarType(FD, FieldTy);
1869+
}
1870+
17541871
bool enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
17551872
CXXCastPath BasePath;
17561873
QualType DerivedTy(RD->getTypeForDecl(), 0);
@@ -1955,6 +2072,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
19552072
return true;
19562073
}
19572074

2075+
bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
2076+
return handleScalarType(FD, FieldTy);
2077+
}
2078+
19582079
bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final {
19592080
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
19602081
return true;

0 commit comments

Comments
 (0)