Skip to content

Commit 95ea801

Browse files
committed
[SYCL] Changes to visitor model in preparation of array support.
Signed-off-by: rdeodhar <[email protected]>
1 parent 7146426 commit 95ea801

File tree

1 file changed

+116
-50
lines changed

1 file changed

+116
-50
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 116 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -686,22 +686,47 @@ static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent,
686686
CXXRecordDecl *Wrapper,
687687
Handlers &... handlers);
688688

689+
template <typename RangeTy, typename... Handlers>
690+
static void VisitField(CXXRecordDecl *Owner, RangeTy Item, QualType ItemTy,
691+
Handlers &... handlers) {
692+
if (Util::isSyclAccessorType(ItemTy)) {
693+
(void)std::initializer_list<int>{
694+
(handlers.handleSyclAccessorType(Item, ItemTy), 0)...};
695+
} else if (Util::isSyclStreamType(ItemTy))
696+
(void)std::initializer_list<int>{
697+
(handlers.handleSyclStreamType(Item, ItemTy), 0)...};
698+
else {
699+
if (ItemTy->isArrayType()) {
700+
VisitArrayElements(Item, ItemTy, handlers...);
701+
} else if (ItemTy->isStructureOrClassType()) {
702+
VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(),
703+
handlers...);
704+
}
705+
}
706+
}
707+
708+
template <typename RangeTy, typename... Handlers>
709+
static void VisitArrayElements(RangeTy Item, QualType FieldTy,
710+
Handlers &... handlers) {
711+
const ConstantArrayType *CAT = cast<ConstantArrayType>(FieldTy);
712+
QualType ET = CAT->getElementType();
713+
int64_t ElemCount = CAT->getSize().getSExtValue();
714+
std::initializer_list<int>{(handlers.enterArray(), 0)...};
715+
for (int64_t Count = 0; Count < ElemCount; Count++) {
716+
VisitField(nullptr, Item, ET, handlers...);
717+
(void)std::initializer_list<int>{(handlers.nextElement(ET), 0)...};
718+
}
719+
(void)std::initializer_list<int>{(handlers.leaveArray(ET, ElemCount), 0)...};
720+
}
721+
689722
template <typename RangeTy, typename... Handlers>
690723
static void VisitAccessorWrapperHelper(CXXRecordDecl *Owner, RangeTy Range,
691724
Handlers &... handlers) {
692725
for (const auto &Item : Range) {
693726
QualType ItemTy = getItemType(Item);
694-
if (Util::isSyclAccessorType(ItemTy))
695-
(void)std::initializer_list<int>{
696-
(handlers.handleSyclAccessorType(Item, ItemTy), 0)...};
697-
else if (Util::isSyclStreamType(ItemTy)) {
698-
VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(),
699-
handlers...);
700-
(void)std::initializer_list<int>{
701-
(handlers.handleSyclStreamType(Item, ItemTy), 0)...};
702-
} else if (ItemTy->isStructureOrClassType())
703-
VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(),
704-
handlers...);
727+
(void)std::initializer_list<int>{(handlers.enterField(Owner, Item), 0)...};
728+
VisitField(Owner, Item, ItemTy, handlers...);
729+
(void)std::initializer_list<int>{(handlers.leaveField(Owner, Item), 0)...};
705730
}
706731
}
707732

@@ -728,6 +753,8 @@ static void VisitRecordFields(RecordDecl::field_range Fields,
728753
(void)std::initializer_list<int> { (handlers.FUNC(Field, FieldTy), 0)... }
729754

730755
for (const auto &Field : Fields) {
756+
(void)std::initializer_list<int>{
757+
(handlers.enterField(nullptr, Field), 0)...};
731758
QualType FieldTy = Field->getType();
732759

733760
if (Util::isSyclAccessorType(FieldTy))
@@ -749,12 +776,15 @@ static void VisitRecordFields(RecordDecl::field_range Fields,
749776
KF_FOR_EACH(handleReferenceType);
750777
else if (FieldTy->isPointerType())
751778
KF_FOR_EACH(handlePointerType);
752-
else if (FieldTy->isArrayType())
779+
else if (FieldTy->isArrayType()) {
753780
KF_FOR_EACH(handleArrayType);
754-
else if (FieldTy->isScalarType())
781+
VisitArrayElements(Field, FieldTy, handlers...);
782+
} else if (FieldTy->isScalarType())
755783
KF_FOR_EACH(handleScalarType);
756784
else
757785
KF_FOR_EACH(handleOtherType);
786+
(void)std::initializer_list<int>{
787+
(handlers.leaveField(nullptr, Field), 0)...};
758788
}
759789
#undef KF_FOR_EACH
760790
}
@@ -780,6 +810,7 @@ template <typename Derived> class SyclKernelFieldHandler {
780810
virtual void handleStructType(FieldDecl *, QualType) {}
781811
virtual void handleReferenceType(FieldDecl *, QualType) {}
782812
virtual void handlePointerType(FieldDecl *, QualType) {}
813+
virtual void handleArrayType(const CXXBaseSpecifier &, QualType) {}
783814
virtual void handleArrayType(FieldDecl *, QualType) {}
784815
virtual void handleScalarType(FieldDecl *, QualType) {}
785816
// Most handlers shouldn't be handling this, just the field checker.
@@ -793,6 +824,17 @@ template <typename Derived> class SyclKernelFieldHandler {
793824
virtual void leaveStruct(const CXXRecordDecl *, FieldDecl *) {}
794825
virtual void enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) {}
795826
virtual void leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) {}
827+
828+
// The following are used for stepping through array elements.
829+
830+
virtual void enterField(const CXXRecordDecl *, const CXXBaseSpecifier &) {}
831+
virtual void leaveField(const CXXRecordDecl *, const CXXBaseSpecifier &) {}
832+
virtual void enterField(const CXXRecordDecl *, FieldDecl *) {}
833+
virtual void leaveField(const CXXRecordDecl *, FieldDecl *) {}
834+
virtual void enterArray(const CXXBaseSpecifier &) {}
835+
virtual void enterArray() {}
836+
virtual void nextElement(QualType) {}
837+
virtual void leaveArray(QualType, int64_t) {}
796838
};
797839

798840
// A type to check the validity of all of the argument types.
@@ -801,6 +843,43 @@ class SyclKernelFieldChecker
801843
bool IsInvalid = false;
802844
DiagnosticsEngine &Diag;
803845

846+
// Check whether the object is bit-wise copyable
847+
bool copyableToKernel(const FieldDecl *FD, const QualType &FieldTy) {
848+
// C++ lambda capture already flags non-constant array types.
849+
// Here, we check copyability.
850+
if (FieldTy->isConstantArrayType()) {
851+
const ConstantArrayType *CAT = cast<ConstantArrayType>(FieldTy);
852+
QualType ET = CAT->getElementType();
853+
return copyableToKernel(FD, ET);
854+
}
855+
if (SemaRef.getASTContext().getLangOpts().SYCLStdLayoutKernelParams) {
856+
if (!FieldTy->isStandardLayoutType()) {
857+
SemaRef.getASTContext().getDiagnostics().Report(
858+
FD->getLocation(), diag::err_sycl_non_std_layout_type)
859+
<< FieldTy;
860+
return false;
861+
}
862+
}
863+
if (!FieldTy->isStructureOrClassType()) {
864+
return true;
865+
}
866+
CXXRecordDecl *RD =
867+
cast<CXXRecordDecl>(FieldTy->getAs<RecordType>()->getDecl());
868+
if (!RD->hasTrivialCopyConstructor()) {
869+
SemaRef.getASTContext().getDiagnostics().Report(
870+
FD->getLocation(), diag::err_sycl_non_trivially_copy_ctor_dtor_type)
871+
<< 0 << FieldTy;
872+
return false;
873+
}
874+
if (!RD->hasTrivialDestructor()) {
875+
SemaRef.getASTContext().getDiagnostics().Report(
876+
FD->getLocation(), diag::err_sycl_non_trivially_copy_ctor_dtor_type)
877+
<< 1 << FieldTy;
878+
return false;
879+
}
880+
return true;
881+
}
882+
804883
public:
805884
SyclKernelFieldChecker(Sema &S)
806885
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
@@ -810,33 +889,17 @@ class SyclKernelFieldChecker
810889
IsInvalid = Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
811890
<< FieldTy;
812891
}
892+
813893
void handleStructType(FieldDecl *FD, QualType FieldTy) final {
814-
if (SemaRef.getASTContext().getLangOpts().SYCLStdLayoutKernelParams &&
815-
!FieldTy->isStandardLayoutType())
816-
IsInvalid =
817-
Diag.Report(FD->getLocation(), diag::err_sycl_non_std_layout_type)
818-
<< FieldTy;
819-
else {
820-
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
821-
if (!RD->hasTrivialCopyConstructor())
822-
823-
IsInvalid =
824-
Diag.Report(FD->getLocation(),
825-
diag::err_sycl_non_trivially_copy_ctor_dtor_type)
826-
<< 0 << FieldTy;
827-
else if (!RD->hasTrivialDestructor())
828-
IsInvalid =
829-
Diag.Report(FD->getLocation(),
830-
diag::err_sycl_non_trivially_copy_ctor_dtor_type)
831-
<< 1 << FieldTy;
832-
}
894+
IsInvalid = !copyableToKernel(FD, FieldTy);
895+
}
896+
897+
void handleArrayType(const CXXBaseSpecifier &, QualType) final {
898+
// FIXME
833899
}
834900

835-
// We should be able to handle this, so we made it part of the visitor, but
836-
// this is 'to be implemented'.
837901
void handleArrayType(FieldDecl *FD, QualType FieldTy) final {
838-
IsInvalid = Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type)
839-
<< FieldTy;
902+
IsInvalid = !copyableToKernel(FD, FieldTy);
840903
}
841904

842905
void handleOtherType(FieldDecl *FD, QualType FieldTy) final {
@@ -1437,20 +1500,23 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
14371500
: CalculatedName);
14381501

14391502
SyclKernelFieldChecker checker(*this);
1440-
SyclKernelDeclCreator kernel_decl(*this, checker, KernelName,
1441-
KernelLambda->getLocation(),
1442-
KernelCallerFunc->isInlined());
1443-
SyclKernelBodyCreator kernel_body(*this, kernel_decl, KernelLambda,
1444-
KernelCallerFunc);
1445-
SyclKernelIntHeaderCreator int_header(
1446-
*this, getSyclIntegrationHeader(), KernelLambda,
1447-
calculateKernelNameType(Context, KernelCallerFunc), KernelName,
1448-
StableName);
1449-
1450-
ConstructingOpenCLKernel = true;
1451-
VisitRecordFields(KernelLambda->fields(), checker, kernel_decl, kernel_body,
1452-
int_header);
1453-
ConstructingOpenCLKernel = false;
1503+
VisitRecordFields(KernelLambda->fields(), checker);
1504+
if (checker.isValid()) {
1505+
SyclKernelDeclCreator kernel_decl(*this, checker, KernelName,
1506+
KernelLambda->getLocation(),
1507+
KernelCallerFunc->isInlined());
1508+
SyclKernelBodyCreator kernel_body(*this, kernel_decl, KernelLambda,
1509+
KernelCallerFunc);
1510+
SyclKernelIntHeaderCreator int_header(
1511+
*this, getSyclIntegrationHeader(), KernelLambda,
1512+
calculateKernelNameType(Context, KernelCallerFunc), KernelName,
1513+
StableName);
1514+
1515+
ConstructingOpenCLKernel = true;
1516+
VisitRecordFields(KernelLambda->fields(), kernel_decl, kernel_body,
1517+
int_header);
1518+
ConstructingOpenCLKernel = false;
1519+
}
14541520
}
14551521

14561522
void Sema::MarkDevice(void) {

0 commit comments

Comments
 (0)