Skip to content

[NOT FOR COMMIT] First try at getting unions supported, does type checking but does not #2270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 106 additions & 1 deletion clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,13 @@ class KernelObjVisitor {
else if (ElementTy->isStructureOrClassType())
VisitRecord(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
handlers...);
else if (ElementTy->isUnionType())
// TODO: This check is still necessary I think?! Array seems to handle
// this differently (see above) for structs I think.
//if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
VisitUnion(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
handlers...);
//}
else if (ElementTy->isArrayType())
VisitArrayElements(ArrayField, ElementTy, handlers...);
else if (ElementTy->isScalarType())
Expand Down Expand Up @@ -849,6 +856,41 @@ class KernelObjVisitor {
void VisitRecord(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper, Handlers &... handlers);

// Base case, only calls these when filtered.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This 'VisitUnion' function is the main trickery. See this to play around with it: (https://godbolt.org/z/z8b61d) .

There are 3 versions of the function. This first is the 'base' case.

The list of 'handlers' not specified via <> syntax (as we do when they aren't specified) are deduced like normal. This ends up in the normal 'Handlers' type pack, with the first of that pack being picked up by the CurHandler.

the return type is done with std::enable_if, which switches based on whether we should VisitUnionBody for the 'first' handler (the one in CurHandler).

If we SHOULD visit it, we call VisitUnion again, this time with the CurHandler put into the "FilteredHandlers" pack (which will never be deduced, since it is not the 'last' thing.
If we SHOULDN'T visit it, we call VIsitUnion again, this time with the CurHandler discarded from the list, and forward on the"FilteredHandlers" (which need to be specified inside <...>, since they cannot be deduced).

If there are no more items to be put into the cur_handler type (or deduced as a Handlers type), the base is called, which does the work.

template <typename... FilteredHandlers, typename ParentTy>
void VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper,
FilteredHandlers &... handlers) {
(void)std::initializer_list<int>{
(handlers.enterUnion(Owner, Parent), 0)...};
VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...);
(void)std::initializer_list<int>{
(handlers.leaveUnion(Owner, Parent), 0)...};
}


template <typename... FilteredHandlers, typename ParentTy,
typename CurHandler, typename... Handlers>
std::enable_if_t<!CurHandler::VisitUnionBody>
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper,
FilteredHandlers &... filtered_handlers,
CurHandler &cur_handler, Handlers &... handlers) {
VisitUnion<FilteredHandlers...>(
Owner, Parent, Wrapper, filtered_handlers..., handlers...);
}

template <typename... FilteredHandlers, typename ParentTy,
typename CurHandler, typename... Handlers>
std::enable_if_t<CurHandler::VisitUnionBody>
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper,
FilteredHandlers &... filtered_handlers,
CurHandler &cur_handler, Handlers &... handlers) {
VisitUnion<FilteredHandlers..., CurHandler>(
Owner, Parent, Wrapper, filtered_handlers..., cur_handler, handlers...);
}

template <typename... Handlers>
void VisitRecordHelper(const CXXRecordDecl *Owner,
clang::CXXRecordDecl::base_class_const_range Range,
Expand Down Expand Up @@ -934,6 +976,11 @@ class KernelObjVisitor {
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
VisitRecord(Owner, Field, RD, handlers...);
}
} else if (FieldTy->isUnionType()) {
if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
VisitUnion(Owner, Field, RD, handlers...);
}
} else if (FieldTy->isReferenceType())
KF_FOR_EACH(handleReferenceType, Field, FieldTy);
else if (FieldTy->isPointerType())
Expand All @@ -949,6 +996,9 @@ class KernelObjVisitor {
(handlers.leaveField(Owner, Field), 0)...};
}
}



#undef KF_FOR_EACH
};
// Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter
Expand All @@ -973,6 +1023,8 @@ class SyclKernelFieldHandler {
SyclKernelFieldHandler(Sema &S) : SemaRef(S) {}

public:

static const constexpr bool VisitUnionBody = false;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default opt-out here.

// Mark these virtual so that we can use override in the implementer classes,
// despite virtual dispatch never being used.

Expand All @@ -997,6 +1049,7 @@ class SyclKernelFieldHandler {
}
virtual bool handleSyclHalfType(FieldDecl *, QualType) { return true; }
virtual bool handleStructType(FieldDecl *, QualType) { return true; }
virtual bool handleUnionType(FieldDecl *, QualType) { return true; }
virtual bool handleReferenceType(FieldDecl *, QualType) { return true; }
virtual bool handlePointerType(FieldDecl *, QualType) { return true; }
virtual bool handleArrayType(FieldDecl *, QualType) { return true; }
Expand All @@ -1016,6 +1069,8 @@ class SyclKernelFieldHandler {
virtual bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) {
return true;
}
virtual bool enterUnion(const CXXRecordDecl *, FieldDecl *) { return true; }
virtual bool leaveUnion(const CXXRecordDecl *, FieldDecl *) { return true; }

// The following are used for stepping through array elements.

Expand All @@ -1037,6 +1092,7 @@ class SyclKernelFieldHandler {
// A type to check the validity of all of the argument types.
class SyclKernelFieldChecker : public SyclKernelFieldHandler {
bool IsInvalid = false;
unsigned UnionCount = 0;
DiagnosticsEngine &Diag;

// Check whether the object should be disallowed from being copied to kernel.
Expand Down Expand Up @@ -1079,7 +1135,6 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
void checkAccessorType(QualType Ty, SourceRange Loc) {
assert(Util::isSyclAccessorType(Ty) &&
"Should only be called on SYCL accessor types.");

const RecordDecl *RecD = Ty->getAsRecordDecl();
if (const ClassTemplateSpecializationDecl *CTSD =
dyn_cast<ClassTemplateSpecializationDecl>(RecD)) {
Expand All @@ -1093,6 +1148,8 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
}

public:
static const constexpr bool VisitUnionBody = true;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how a handler opts-into this.


SyclKernelFieldChecker(Sema &S)
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
bool isValid() { return !IsInvalid; }
Expand All @@ -1108,13 +1165,39 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
return isValid();
}

bool handlePointerType(FieldDecl *FD, QualType FieldTy) final {
// TODO: Replace with a better diagnostic.
if (UnionCount) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered making this a separate handler that only overloaded enter/leave Union and the handleSyclAccessorType/handleSyclPointerType.

We might also consider still doing this, and putting it in after the normal FieldChecker. Note the bad error messages.

IsInvalid = true;
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
<< FieldTy;
}
return isValid();
}

bool handleSyclAccessorType(const CXXBaseSpecifier &BS,
QualType FieldTy) final {
// TODO: Replace with a better diagnostic.
if (UnionCount) {
IsInvalid = true;
Diag.Report(BS.getBeginLoc(), diag::err_bad_kernel_param_type) << FieldTy;
return isValid();
}

checkAccessorType(FieldTy, BS.getBeginLoc());
return isValid();
}

bool handleSyclAccessorType(FieldDecl *FD, QualType FieldTy) final {
// TODO: Replace with a better diagnostic.
// TODO: What other types do we need ot check? What types other than
// pointer/accessor requires a decomposition?
Comment on lines +1193 to +1194
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to check sampler and stream types. Union does not work with them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, thanks! I'm also considering breaking the union checking into a separate handler, so it would need to handle all of these.

if (UnionCount) {
IsInvalid = true;
Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) << FieldTy;
return isValid();
}

checkAccessorType(FieldTy, FD->getLocation());
return isValid();
}
Expand All @@ -1129,6 +1212,15 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
IsInvalid = true;
return isValid();
}

bool enterUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
++UnionCount;
return true;
}
bool leaveUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
--UnionCount;
return true;
}
};

// A type to Create and own the FunctionDecl for the kernel.
Expand Down Expand Up @@ -1291,6 +1383,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
return true;
}

bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
return handleScalarType(FD, FieldTy);
}

bool handleSyclHalfType(FieldDecl *FD, QualType FieldTy) final {
addParam(FD, FieldTy);
return true;
Expand Down Expand Up @@ -1626,6 +1722,11 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
return true;
}

bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
return handleScalarType(FD, FieldTy);
}


bool enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
CXXCastPath BasePath;
QualType DerivedTy(RD->getTypeForDecl(), 0);
Expand Down Expand Up @@ -1830,6 +1931,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
return true;
}

bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
return handleScalarType(FD, FieldTy);
}

bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final {
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
return true;
Expand Down