@@ -686,22 +686,47 @@ static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent,
686
686
CXXRecordDecl *Wrapper,
687
687
Handlers &... handlers);
688
688
689
+ template <typename RangeTy, typename ... Handlers>
690
+ static void VisitField (CXXRecordDecl *Owner, RangeTy Item, QualType ItemTy,
691
+ Handlers &... handlers) {
692
+ if (Util::isSyclAccessorType (ItemTy)) {
693
+ (void )std::initializer_list<int >{
694
+ (handlers.handleSyclAccessorType (Item, ItemTy), 0 )...};
695
+ } else if (Util::isSyclStreamType (ItemTy))
696
+ (void )std::initializer_list<int >{
697
+ (handlers.handleSyclStreamType (Item, ItemTy), 0 )...};
698
+ else {
699
+ if (ItemTy->isArrayType ()) {
700
+ VisitArrayElements (Item, ItemTy, handlers...);
701
+ } else if (ItemTy->isStructureOrClassType ()) {
702
+ VisitAccessorWrapper (Owner, Item, ItemTy->getAsCXXRecordDecl (),
703
+ handlers...);
704
+ }
705
+ }
706
+ }
707
+
708
+ template <typename RangeTy, typename ... Handlers>
709
+ static void VisitArrayElements (RangeTy Item, QualType FieldTy,
710
+ Handlers &... handlers) {
711
+ const ConstantArrayType *CAT = cast<ConstantArrayType>(FieldTy);
712
+ QualType ET = CAT->getElementType ();
713
+ int64_t ElemCount = CAT->getSize ().getSExtValue ();
714
+ std::initializer_list<int >{(handlers.enterArray (), 0 )...};
715
+ for (int64_t Count = 0 ; Count < ElemCount; Count++) {
716
+ VisitField (nullptr , Item, ET, handlers...);
717
+ (void )std::initializer_list<int >{(handlers.nextElement (ET), 0 )...};
718
+ }
719
+ (void )std::initializer_list<int >{(handlers.leaveArray (ET, ElemCount), 0 )...};
720
+ }
721
+
689
722
template <typename RangeTy, typename ... Handlers>
690
723
static void VisitAccessorWrapperHelper (CXXRecordDecl *Owner, RangeTy Range,
691
724
Handlers &... handlers) {
692
725
for (const auto &Item : Range) {
693
726
QualType ItemTy = getItemType (Item);
694
- if (Util::isSyclAccessorType (ItemTy))
695
- (void )std::initializer_list<int >{
696
- (handlers.handleSyclAccessorType (Item, ItemTy), 0 )...};
697
- else if (Util::isSyclStreamType (ItemTy)) {
698
- VisitAccessorWrapper (Owner, Item, ItemTy->getAsCXXRecordDecl (),
699
- handlers...);
700
- (void )std::initializer_list<int >{
701
- (handlers.handleSyclStreamType (Item, ItemTy), 0 )...};
702
- } else if (ItemTy->isStructureOrClassType ())
703
- VisitAccessorWrapper (Owner, Item, ItemTy->getAsCXXRecordDecl (),
704
- handlers...);
727
+ (void )std::initializer_list<int >{(handlers.enterField (Owner, Item), 0 )...};
728
+ VisitField (Owner, Item, ItemTy, handlers...);
729
+ (void )std::initializer_list<int >{(handlers.leaveField (Owner, Item), 0 )...};
705
730
}
706
731
}
707
732
@@ -728,6 +753,8 @@ static void VisitRecordFields(RecordDecl::field_range Fields,
728
753
(void )std::initializer_list<int > { (handlers.FUNC (Field, FieldTy), 0 )... }
729
754
730
755
for (const auto &Field : Fields) {
756
+ (void )std::initializer_list<int >{
757
+ (handlers.enterField (nullptr , Field), 0 )...};
731
758
QualType FieldTy = Field->getType ();
732
759
733
760
if (Util::isSyclAccessorType (FieldTy))
@@ -749,12 +776,15 @@ static void VisitRecordFields(RecordDecl::field_range Fields,
749
776
KF_FOR_EACH (handleReferenceType);
750
777
else if (FieldTy->isPointerType ())
751
778
KF_FOR_EACH (handlePointerType);
752
- else if (FieldTy->isArrayType ())
779
+ else if (FieldTy->isArrayType ()) {
753
780
KF_FOR_EACH (handleArrayType);
754
- else if (FieldTy->isScalarType ())
781
+ VisitArrayElements (Field, FieldTy, handlers...);
782
+ } else if (FieldTy->isScalarType ())
755
783
KF_FOR_EACH (handleScalarType);
756
784
else
757
785
KF_FOR_EACH (handleOtherType);
786
+ (void )std::initializer_list<int >{
787
+ (handlers.leaveField (nullptr , Field), 0 )...};
758
788
}
759
789
#undef KF_FOR_EACH
760
790
}
@@ -780,6 +810,7 @@ template <typename Derived> class SyclKernelFieldHandler {
780
810
virtual void handleStructType (FieldDecl *, QualType) {}
781
811
virtual void handleReferenceType (FieldDecl *, QualType) {}
782
812
virtual void handlePointerType (FieldDecl *, QualType) {}
813
+ virtual void handleArrayType (const CXXBaseSpecifier &, QualType) {}
783
814
virtual void handleArrayType (FieldDecl *, QualType) {}
784
815
virtual void handleScalarType (FieldDecl *, QualType) {}
785
816
// Most handlers shouldn't be handling this, just the field checker.
@@ -793,6 +824,17 @@ template <typename Derived> class SyclKernelFieldHandler {
793
824
virtual void leaveStruct (const CXXRecordDecl *, FieldDecl *) {}
794
825
virtual void enterStruct (const CXXRecordDecl *, const CXXBaseSpecifier &) {}
795
826
virtual void leaveStruct (const CXXRecordDecl *, const CXXBaseSpecifier &) {}
827
+
828
+ // The following are used for stepping through array elements.
829
+
830
+ virtual void enterField (const CXXRecordDecl *, const CXXBaseSpecifier &) {}
831
+ virtual void leaveField (const CXXRecordDecl *, const CXXBaseSpecifier &) {}
832
+ virtual void enterField (const CXXRecordDecl *, FieldDecl *) {}
833
+ virtual void leaveField (const CXXRecordDecl *, FieldDecl *) {}
834
+ virtual void enterArray (const CXXBaseSpecifier &) {}
835
+ virtual void enterArray () {}
836
+ virtual void nextElement (QualType) {}
837
+ virtual void leaveArray (QualType, int64_t ) {}
796
838
};
797
839
798
840
// A type to check the validity of all of the argument types.
@@ -801,6 +843,43 @@ class SyclKernelFieldChecker
801
843
bool IsInvalid = false ;
802
844
DiagnosticsEngine &Diag;
803
845
846
+ // Check whether the object is bit-wise copyable
847
+ bool copyableToKernel (const FieldDecl *FD, const QualType &FieldTy) {
848
+ // C++ lambda capture already flags non-constant array types.
849
+ // Here, we check copyability.
850
+ if (FieldTy->isConstantArrayType ()) {
851
+ const ConstantArrayType *CAT = cast<ConstantArrayType>(FieldTy);
852
+ QualType ET = CAT->getElementType ();
853
+ return copyableToKernel (FD, ET);
854
+ }
855
+ if (SemaRef.getASTContext ().getLangOpts ().SYCLStdLayoutKernelParams ) {
856
+ if (!FieldTy->isStandardLayoutType ()) {
857
+ SemaRef.getASTContext ().getDiagnostics ().Report (
858
+ FD->getLocation (), diag::err_sycl_non_std_layout_type)
859
+ << FieldTy;
860
+ return false ;
861
+ }
862
+ }
863
+ if (!FieldTy->isStructureOrClassType ()) {
864
+ return true ;
865
+ }
866
+ CXXRecordDecl *RD =
867
+ cast<CXXRecordDecl>(FieldTy->getAs <RecordType>()->getDecl ());
868
+ if (!RD->hasTrivialCopyConstructor ()) {
869
+ SemaRef.getASTContext ().getDiagnostics ().Report (
870
+ FD->getLocation (), diag::err_sycl_non_trivially_copy_ctor_dtor_type)
871
+ << 0 << FieldTy;
872
+ return false ;
873
+ }
874
+ if (!RD->hasTrivialDestructor ()) {
875
+ SemaRef.getASTContext ().getDiagnostics ().Report (
876
+ FD->getLocation (), diag::err_sycl_non_trivially_copy_ctor_dtor_type)
877
+ << 1 << FieldTy;
878
+ return false ;
879
+ }
880
+ return true ;
881
+ }
882
+
804
883
public:
805
884
SyclKernelFieldChecker (Sema &S)
806
885
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
@@ -810,33 +889,17 @@ class SyclKernelFieldChecker
810
889
IsInvalid = Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
811
890
<< FieldTy;
812
891
}
892
+
813
893
void handleStructType (FieldDecl *FD, QualType FieldTy) final {
814
- if (SemaRef.getASTContext ().getLangOpts ().SYCLStdLayoutKernelParams &&
815
- !FieldTy->isStandardLayoutType ())
816
- IsInvalid =
817
- Diag.Report (FD->getLocation (), diag::err_sycl_non_std_layout_type)
818
- << FieldTy;
819
- else {
820
- CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl ();
821
- if (!RD->hasTrivialCopyConstructor ())
822
-
823
- IsInvalid =
824
- Diag.Report (FD->getLocation (),
825
- diag::err_sycl_non_trivially_copy_ctor_dtor_type)
826
- << 0 << FieldTy;
827
- else if (!RD->hasTrivialDestructor ())
828
- IsInvalid =
829
- Diag.Report (FD->getLocation (),
830
- diag::err_sycl_non_trivially_copy_ctor_dtor_type)
831
- << 1 << FieldTy;
832
- }
894
+ IsInvalid = !copyableToKernel (FD, FieldTy);
895
+ }
896
+
897
+ void handleArrayType (const CXXBaseSpecifier &, QualType) final {
898
+ // FIXME
833
899
}
834
900
835
- // We should be able to handle this, so we made it part of the visitor, but
836
- // this is 'to be implemented'.
837
901
void handleArrayType (FieldDecl *FD, QualType FieldTy) final {
838
- IsInvalid = Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
839
- << FieldTy;
902
+ IsInvalid = !copyableToKernel (FD, FieldTy);
840
903
}
841
904
842
905
void handleOtherType (FieldDecl *FD, QualType FieldTy) final {
@@ -1437,20 +1500,23 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
1437
1500
: CalculatedName);
1438
1501
1439
1502
SyclKernelFieldChecker checker (*this );
1440
- SyclKernelDeclCreator kernel_decl (*this , checker, KernelName,
1441
- KernelLambda->getLocation (),
1442
- KernelCallerFunc->isInlined ());
1443
- SyclKernelBodyCreator kernel_body (*this , kernel_decl, KernelLambda,
1444
- KernelCallerFunc);
1445
- SyclKernelIntHeaderCreator int_header (
1446
- *this , getSyclIntegrationHeader (), KernelLambda,
1447
- calculateKernelNameType (Context, KernelCallerFunc), KernelName,
1448
- StableName);
1449
-
1450
- ConstructingOpenCLKernel = true ;
1451
- VisitRecordFields (KernelLambda->fields (), checker, kernel_decl, kernel_body,
1452
- int_header);
1453
- ConstructingOpenCLKernel = false ;
1503
+ VisitRecordFields (KernelLambda->fields (), checker);
1504
+ if (checker.isValid ()) {
1505
+ SyclKernelDeclCreator kernel_decl (*this , checker, KernelName,
1506
+ KernelLambda->getLocation (),
1507
+ KernelCallerFunc->isInlined ());
1508
+ SyclKernelBodyCreator kernel_body (*this , kernel_decl, KernelLambda,
1509
+ KernelCallerFunc);
1510
+ SyclKernelIntHeaderCreator int_header (
1511
+ *this , getSyclIntegrationHeader (), KernelLambda,
1512
+ calculateKernelNameType (Context, KernelCallerFunc), KernelName,
1513
+ StableName);
1514
+
1515
+ ConstructingOpenCLKernel = true ;
1516
+ VisitRecordFields (KernelLambda->fields (), kernel_decl, kernel_body,
1517
+ int_header);
1518
+ ConstructingOpenCLKernel = false ;
1519
+ }
1454
1520
}
1455
1521
1456
1522
void Sema::MarkDevice (void ) {
0 commit comments