-
Notifications
You must be signed in to change notification settings - Fork 788
[NOT FOR COMMIT] First try at getting unions supported, does type checking but does not #2270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -822,6 +822,13 @@ class KernelObjVisitor { | |
else if (ElementTy->isStructureOrClassType()) | ||
VisitRecord(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(), | ||
handlers...); | ||
else if (ElementTy->isUnionType()) | ||
// TODO: This check is still necessary I think?! Array seems to handle | ||
// this differently (see above) for structs I think. | ||
//if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) { | ||
VisitUnion(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(), | ||
handlers...); | ||
//} | ||
else if (ElementTy->isArrayType()) | ||
VisitArrayElements(ArrayField, ElementTy, handlers...); | ||
else if (ElementTy->isScalarType()) | ||
|
@@ -849,6 +856,41 @@ class KernelObjVisitor { | |
void VisitRecord(const CXXRecordDecl *Owner, ParentTy &Parent, | ||
const CXXRecordDecl *Wrapper, Handlers &... handlers); | ||
|
||
// Base case, only calls these when filtered. | ||
template <typename... FilteredHandlers, typename ParentTy> | ||
void VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent, | ||
const CXXRecordDecl *Wrapper, | ||
FilteredHandlers &... handlers) { | ||
(void)std::initializer_list<int>{ | ||
(handlers.enterUnion(Owner, Parent), 0)...}; | ||
VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...); | ||
(void)std::initializer_list<int>{ | ||
(handlers.leaveUnion(Owner, Parent), 0)...}; | ||
} | ||
|
||
|
||
template <typename... FilteredHandlers, typename ParentTy, | ||
typename CurHandler, typename... Handlers> | ||
std::enable_if_t<!CurHandler::VisitUnionBody> | ||
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent, | ||
const CXXRecordDecl *Wrapper, | ||
FilteredHandlers &... filtered_handlers, | ||
CurHandler &cur_handler, Handlers &... handlers) { | ||
VisitUnion<FilteredHandlers...>( | ||
Owner, Parent, Wrapper, filtered_handlers..., handlers...); | ||
} | ||
|
||
template <typename... FilteredHandlers, typename ParentTy, | ||
typename CurHandler, typename... Handlers> | ||
std::enable_if_t<CurHandler::VisitUnionBody> | ||
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent, | ||
const CXXRecordDecl *Wrapper, | ||
FilteredHandlers &... filtered_handlers, | ||
CurHandler &cur_handler, Handlers &... handlers) { | ||
VisitUnion<FilteredHandlers..., CurHandler>( | ||
Owner, Parent, Wrapper, filtered_handlers..., cur_handler, handlers...); | ||
} | ||
|
||
template <typename... Handlers> | ||
void VisitRecordHelper(const CXXRecordDecl *Owner, | ||
clang::CXXRecordDecl::base_class_const_range Range, | ||
|
@@ -934,6 +976,11 @@ class KernelObjVisitor { | |
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); | ||
VisitRecord(Owner, Field, RD, handlers...); | ||
} | ||
} else if (FieldTy->isUnionType()) { | ||
if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) { | ||
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); | ||
VisitUnion(Owner, Field, RD, handlers...); | ||
} | ||
} else if (FieldTy->isReferenceType()) | ||
KF_FOR_EACH(handleReferenceType, Field, FieldTy); | ||
else if (FieldTy->isPointerType()) | ||
|
@@ -949,6 +996,9 @@ class KernelObjVisitor { | |
(handlers.leaveField(Owner, Field), 0)...}; | ||
} | ||
} | ||
|
||
|
||
|
||
#undef KF_FOR_EACH | ||
}; | ||
// Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter | ||
|
@@ -973,6 +1023,8 @@ class SyclKernelFieldHandler { | |
SyclKernelFieldHandler(Sema &S) : SemaRef(S) {} | ||
|
||
public: | ||
|
||
static const constexpr bool VisitUnionBody = false; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Default opt-out here. |
||
// Mark these virtual so that we can use override in the implementer classes, | ||
// despite virtual dispatch never being used. | ||
|
||
|
@@ -997,6 +1049,7 @@ class SyclKernelFieldHandler { | |
} | ||
virtual bool handleSyclHalfType(FieldDecl *, QualType) { return true; } | ||
virtual bool handleStructType(FieldDecl *, QualType) { return true; } | ||
virtual bool handleUnionType(FieldDecl *, QualType) { return true; } | ||
virtual bool handleReferenceType(FieldDecl *, QualType) { return true; } | ||
virtual bool handlePointerType(FieldDecl *, QualType) { return true; } | ||
virtual bool handleArrayType(FieldDecl *, QualType) { return true; } | ||
|
@@ -1016,6 +1069,8 @@ class SyclKernelFieldHandler { | |
virtual bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) { | ||
return true; | ||
} | ||
virtual bool enterUnion(const CXXRecordDecl *, FieldDecl *) { return true; } | ||
virtual bool leaveUnion(const CXXRecordDecl *, FieldDecl *) { return true; } | ||
|
||
// The following are used for stepping through array elements. | ||
|
||
|
@@ -1037,6 +1092,7 @@ class SyclKernelFieldHandler { | |
// A type to check the validity of all of the argument types. | ||
class SyclKernelFieldChecker : public SyclKernelFieldHandler { | ||
bool IsInvalid = false; | ||
unsigned UnionCount = 0; | ||
DiagnosticsEngine &Diag; | ||
|
||
// Check whether the object should be disallowed from being copied to kernel. | ||
|
@@ -1079,7 +1135,6 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { | |
void checkAccessorType(QualType Ty, SourceRange Loc) { | ||
assert(Util::isSyclAccessorType(Ty) && | ||
"Should only be called on SYCL accessor types."); | ||
|
||
const RecordDecl *RecD = Ty->getAsRecordDecl(); | ||
if (const ClassTemplateSpecializationDecl *CTSD = | ||
dyn_cast<ClassTemplateSpecializationDecl>(RecD)) { | ||
|
@@ -1093,6 +1148,8 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { | |
} | ||
|
||
public: | ||
static const constexpr bool VisitUnionBody = true; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is how a handler opts-into this. |
||
|
||
SyclKernelFieldChecker(Sema &S) | ||
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {} | ||
bool isValid() { return !IsInvalid; } | ||
|
@@ -1108,13 +1165,39 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { | |
return isValid(); | ||
} | ||
|
||
bool handlePointerType(FieldDecl *FD, QualType FieldTy) final { | ||
// TODO: Replace with a better diagnostic. | ||
if (UnionCount) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I considered making this a separate handler that only overloaded enter/leave Union and the handleSyclAccessorType/handleSyclPointerType. We might also consider still doing this, and putting it in after the normal FieldChecker. Note the bad error messages. |
||
IsInvalid = true; | ||
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) | ||
<< FieldTy; | ||
} | ||
return isValid(); | ||
} | ||
|
||
bool handleSyclAccessorType(const CXXBaseSpecifier &BS, | ||
QualType FieldTy) final { | ||
// TODO: Replace with a better diagnostic. | ||
if (UnionCount) { | ||
IsInvalid = true; | ||
Diag.Report(BS.getBeginLoc(), diag::err_bad_kernel_param_type) << FieldTy; | ||
return isValid(); | ||
} | ||
|
||
checkAccessorType(FieldTy, BS.getBeginLoc()); | ||
return isValid(); | ||
} | ||
|
||
bool handleSyclAccessorType(FieldDecl *FD, QualType FieldTy) final { | ||
// TODO: Replace with a better diagnostic. | ||
// TODO: What other types do we need ot check? What types other than | ||
// pointer/accessor requires a decomposition? | ||
Comment on lines
+1193
to
+1194
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to check sampler and stream types. Union does not work with them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cool, thanks! I'm also considering breaking the union checking into a separate handler, so it would need to handle all of these. |
||
if (UnionCount) { | ||
IsInvalid = true; | ||
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) << FieldTy; | ||
return isValid(); | ||
} | ||
|
||
checkAccessorType(FieldTy, FD->getLocation()); | ||
return isValid(); | ||
} | ||
|
@@ -1129,6 +1212,15 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { | |
IsInvalid = true; | ||
return isValid(); | ||
} | ||
|
||
bool enterUnion(const CXXRecordDecl *RD, FieldDecl *FD) { | ||
++UnionCount; | ||
return true; | ||
} | ||
bool leaveUnion(const CXXRecordDecl *RD, FieldDecl *FD) { | ||
--UnionCount; | ||
return true; | ||
} | ||
}; | ||
|
||
// A type to Create and own the FunctionDecl for the kernel. | ||
|
@@ -1291,6 +1383,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler { | |
return true; | ||
} | ||
|
||
bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { | ||
return handleScalarType(FD, FieldTy); | ||
} | ||
|
||
bool handleSyclHalfType(FieldDecl *FD, QualType FieldTy) final { | ||
addParam(FD, FieldTy); | ||
return true; | ||
|
@@ -1626,6 +1722,11 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { | |
return true; | ||
} | ||
|
||
bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { | ||
return handleScalarType(FD, FieldTy); | ||
} | ||
|
||
|
||
bool enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final { | ||
CXXCastPath BasePath; | ||
QualType DerivedTy(RD->getTypeForDecl(), 0); | ||
|
@@ -1830,6 +1931,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { | |
return true; | ||
} | ||
|
||
bool handleUnionType(FieldDecl *FD, QualType FieldTy) final { | ||
return handleScalarType(FD, FieldTy); | ||
} | ||
|
||
bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final { | ||
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout); | ||
return true; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This 'VisitUnion' function is the main trickery. See this to play around with it: (https://godbolt.org/z/z8b61d) .
There are 3 versions of the function. This first is the 'base' case.
The list of 'handlers' not specified via <> syntax (as we do when they aren't specified) are deduced like normal. This ends up in the normal 'Handlers' type pack, with the first of that pack being picked up by the CurHandler.
the return type is done with std::enable_if, which switches based on whether we should VisitUnionBody for the 'first' handler (the one in CurHandler).
If we SHOULD visit it, we call VisitUnion again, this time with the CurHandler put into the "FilteredHandlers" pack (which will never be deduced, since it is not the 'last' thing.
If we SHOULDN'T visit it, we call VIsitUnion again, this time with the CurHandler discarded from the list, and forward on the"FilteredHandlers" (which need to be specified inside <...>, since they cannot be deduced).
If there are no more items to be put into the cur_handler type (or deduced as a Handlers type), the base is called, which does the work.