@@ -454,9 +454,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
454
454
return Res;
455
455
};
456
456
457
- QualType FieldType = Field-> getType ();
458
- CXXRecordDecl *CRD = FieldType-> getAsCXXRecordDecl ();
459
- if ( CRD && Util::isSyclAccessorType (FieldType) ) {
457
+ auto getExprForAccessorInit = [&]( const QualType ¶mTy,
458
+ FieldDecl *Field,
459
+ const CXXRecordDecl * CRD, Expr *Base ) {
460
460
// Since this is an accessor next 4 TargetFuncParams including current
461
461
// should be set in __init method: _ValueType*, range<int>, range<int>,
462
462
// id<int>
@@ -472,9 +472,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
472
472
std::advance (TargetFuncParam, NumParams - 1 );
473
473
474
474
DeclAccessPair FieldDAP = DeclAccessPair::make (Field, AS_none);
475
- // kernel_obj .accessor
475
+ // [kenrel_obj or wrapper object] .accessor
476
476
auto AccessorME = MemberExpr::Create (
477
- S.Context , CloneRef , false , SourceLocation (),
477
+ S.Context , Base , false , SourceLocation (),
478
478
NestedNameSpecifierLoc (), SourceLocation (), Field, FieldDAP,
479
479
DeclarationNameInfo (Field->getDeclName (), SourceLocation ()),
480
480
nullptr , Field->getType (), VK_LValue, OK_Ordinary);
@@ -488,7 +488,7 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
488
488
}
489
489
assert (InitMethod && " The accessor must have the __init method" );
490
490
491
- // kernel_obj .accessor.__init
491
+ // [kenrel_obj or wrapper object] .accessor.__init
492
492
DeclAccessPair MethodDAP = DeclAccessPair::make (InitMethod, AS_none);
493
493
auto ME = MemberExpr::Create (
494
494
S.Context , AccessorME, false , SourceLocation (),
@@ -515,11 +515,52 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
515
515
S, ((*ParamItr++))->getOriginalType (), ParamDREs[2 ]));
516
516
ParamStmts.push_back (getExprForRangeOrOffset (
517
517
S, ((*ParamItr++))->getOriginalType (), ParamDREs[3 ]));
518
- // kernel_obj .accessor.__init(_ValueType*, range<int>, range<int> ,
519
- // id<int>)
518
+ // [kenrel_obj or wrapper object] .accessor.__init(_ValueType*,
519
+ // range<int>, range<int>, id<int>)
520
520
CXXMemberCallExpr *Call = CXXMemberCallExpr::Create (
521
521
S.Context , ME, ParamStmts, ResultTy, VK, SourceLocation ());
522
522
BodyStmts.push_back (Call);
523
+ };
524
+
525
+ // Recursively search for accessor fields to initialize them with kernel
526
+ // parameters
527
+ std::function<void (const CXXRecordDecl *, Expr *)>
528
+ getExprForWrappedAccessorInit = [&](const CXXRecordDecl *CRD,
529
+ Expr *Base) {
530
+ for (auto *WrapperFld : CRD->fields ()) {
531
+ QualType FldType = WrapperFld->getType ();
532
+ CXXRecordDecl *WrapperFldCRD = FldType->getAsCXXRecordDecl ();
533
+ if (FldType->isStructureOrClassType ()) {
534
+ if (Util::isSyclAccessorType (FldType)) {
535
+ // Accessor field found - create expr to initialize this
536
+ // accessor object. Need to start from the next target
537
+ // function parameter, since current one is the wrapper object
538
+ // or parameter of the previous processed accessor object.
539
+ TargetFuncParam++;
540
+ getExprForAccessorInit (FldType, WrapperFld, WrapperFldCRD,
541
+ Base);
542
+ } else {
543
+ // Field is a structure or class so change the wrapper object
544
+ // and recursively search for accessor field.
545
+ DeclAccessPair WrapperFieldDAP =
546
+ DeclAccessPair::make (WrapperFld, AS_none);
547
+ auto NewBase = MemberExpr::Create (
548
+ S.Context , Base, false , SourceLocation (),
549
+ NestedNameSpecifierLoc (), SourceLocation (), WrapperFld,
550
+ WrapperFieldDAP,
551
+ DeclarationNameInfo (WrapperFld->getDeclName (),
552
+ SourceLocation ()),
553
+ nullptr , WrapperFld->getType (), VK_LValue, OK_Ordinary);
554
+ getExprForWrappedAccessorInit (WrapperFldCRD, NewBase);
555
+ }
556
+ }
557
+ }
558
+ };
559
+
560
+ QualType FieldType = Field->getType ();
561
+ CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl ();
562
+ if (Util::isSyclAccessorType (FieldType)) {
563
+ getExprForAccessorInit (FieldType, Field, CRD, CloneRef);
523
564
} else if (CRD && Util::isSyclSamplerType (FieldType)) {
524
565
525
566
// Sampler has only one TargetFuncParam, which should be set in
@@ -596,6 +637,12 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
596
637
BinaryOperator (Lhs, Rhs, BO_Assign, FieldType, VK_LValue,
597
638
OK_Ordinary, SourceLocation (), FPOptions ());
598
639
BodyStmts.push_back (Res);
640
+
641
+ // If a structure/class type has accessor fields then we need to
642
+ // initialize these accessors in proper way by calling __init method of
643
+ // the accessor and passing corresponding kernel parameters.
644
+ if (CRD)
645
+ getExprForWrappedAccessorInit (CRD, Lhs);
599
646
} else {
600
647
llvm_unreachable (" unsupported field type" );
601
648
}
@@ -675,56 +722,78 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
675
722
// create a parameter descriptor and append it to the result
676
723
ParamDescs.push_back (makeParamDesc (Fld, ArgType));
677
724
};
725
+
726
+ auto createAccessorParamDesc = [&](const FieldDecl *Fld,
727
+ const QualType &ArgTy) {
728
+ // the parameter is a SYCL accessor object
729
+ const auto *RecordDecl = ArgTy->getAsCXXRecordDecl ();
730
+ assert (RecordDecl && " accessor must be of a record type" );
731
+ const auto *TemplateDecl =
732
+ cast<ClassTemplateSpecializationDecl>(RecordDecl);
733
+ // First accessor template parameter - data type
734
+ QualType PointeeType = TemplateDecl->getTemplateArgs ()[0 ].getAsType ();
735
+ // Fourth parameter - access target
736
+ target AccessTarget = getAccessTarget (TemplateDecl);
737
+ Qualifiers Quals = PointeeType.getQualifiers ();
738
+ // TODO: Support all access targets
739
+ switch (AccessTarget) {
740
+ case target::global_buffer:
741
+ Quals.setAddressSpace (LangAS::opencl_global);
742
+ break ;
743
+ case target::constant_buffer:
744
+ Quals.setAddressSpace (LangAS::opencl_constant);
745
+ break ;
746
+ case target::local:
747
+ Quals.setAddressSpace (LangAS::opencl_local);
748
+ break ;
749
+ default :
750
+ llvm_unreachable (" Unsupported access target" );
751
+ }
752
+ PointeeType =
753
+ Context.getQualifiedType (PointeeType.getUnqualifiedType (), Quals);
754
+ QualType PointerType = Context.getPointerType (PointeeType);
755
+
756
+ CreateAndAddPrmDsc (Fld, PointerType);
757
+
758
+ FieldDecl *AccessRangeFld =
759
+ getFieldDeclByName (RecordDecl, {" impl" , " AccessRange" });
760
+ assert (AccessRangeFld &&
761
+ " The accessor.impl must contain the AccessRange field" );
762
+ CreateAndAddPrmDsc (AccessRangeFld, AccessRangeFld->getType ());
763
+
764
+ FieldDecl *MemRangeFld =
765
+ getFieldDeclByName (RecordDecl, {" impl" , " MemRange" });
766
+ assert (MemRangeFld && " The accessor.impl must contain the MemRange field" );
767
+ CreateAndAddPrmDsc (MemRangeFld, MemRangeFld->getType ());
768
+
769
+ FieldDecl *OffsetFld = getFieldDeclByName (RecordDecl, {" impl" , " Offset" });
770
+ assert (OffsetFld && " The accessor.impl must contain the Offset field" );
771
+ CreateAndAddPrmDsc (OffsetFld, OffsetFld->getType ());
772
+ };
773
+
774
+ std::function<void (const FieldDecl *, const QualType &ArgTy)>
775
+ createParamDescForWrappedAccessors =
776
+ [&](const FieldDecl *Fld, const QualType &ArgTy) {
777
+ const auto *Wrapper = ArgTy->getAsCXXRecordDecl ();
778
+ for (const auto *WrapperFld : Wrapper->fields ()) {
779
+ QualType FldType = WrapperFld->getType ();
780
+ if (FldType->isStructureOrClassType ()) {
781
+ if (Util::isSyclAccessorType (FldType)) {
782
+ // accessor field is found - create descriptor
783
+ createAccessorParamDesc (WrapperFld, FldType);
784
+ } else {
785
+ // field is some class or struct - recursively check for
786
+ // accessor fields
787
+ createParamDescForWrappedAccessors (WrapperFld, FldType);
788
+ }
789
+ }
790
+ }
791
+ };
792
+
678
793
for (const auto *Fld : KernelObj->fields ()) {
679
794
QualType ArgTy = Fld->getType ();
680
795
if (Util::isSyclAccessorType (ArgTy)) {
681
- // the parameter is a SYCL accessor object
682
- const auto *RecordDecl = ArgTy->getAsCXXRecordDecl ();
683
- assert (RecordDecl && " accessor must be of a record type" );
684
- const auto *TemplateDecl =
685
- cast<ClassTemplateSpecializationDecl>(RecordDecl);
686
- // First accessor template parameter - data type
687
- QualType PointeeType = TemplateDecl->getTemplateArgs ()[0 ].getAsType ();
688
- // Fourth parameter - access target
689
- target AccessTarget = getAccessTarget (TemplateDecl);
690
- Qualifiers Quals = PointeeType.getQualifiers ();
691
- // TODO: Support all access targets
692
- switch (AccessTarget) {
693
- case target::global_buffer:
694
- Quals.setAddressSpace (LangAS::opencl_global);
695
- break ;
696
- case target::constant_buffer:
697
- Quals.setAddressSpace (LangAS::opencl_constant);
698
- break ;
699
- case target::local:
700
- Quals.setAddressSpace (LangAS::opencl_local);
701
- break ;
702
- default :
703
- llvm_unreachable (" Unsupported access target" );
704
- }
705
- // TODO: get address space from accessor template parameter.
706
- PointeeType =
707
- Context.getQualifiedType (PointeeType.getUnqualifiedType (), Quals);
708
- QualType PointerType = Context.getPointerType (PointeeType);
709
-
710
- CreateAndAddPrmDsc (Fld, PointerType);
711
-
712
- FieldDecl *AccessRangeFld =
713
- getFieldDeclByName (RecordDecl, {" impl" , " AccessRange" });
714
- assert (AccessRangeFld &&
715
- " The accessor.impl must contain the AccessRange field" );
716
- CreateAndAddPrmDsc (AccessRangeFld, AccessRangeFld->getType ());
717
-
718
- FieldDecl *MemRangeFld =
719
- getFieldDeclByName (RecordDecl, {" impl" , " MemRange" });
720
- assert (MemRangeFld &&
721
- " The accessor.impl must contain the MemRange field" );
722
- CreateAndAddPrmDsc (MemRangeFld, MemRangeFld->getType ());
723
-
724
- FieldDecl *OffsetFld =
725
- getFieldDeclByName (RecordDecl, {" impl" , " Offset" });
726
- assert (OffsetFld && " The accessor.impl must contain the Offset field" );
727
- CreateAndAddPrmDsc (OffsetFld, OffsetFld->getType ());
796
+ createAccessorParamDesc (Fld, ArgTy);
728
797
} else if (Util::isSyclSamplerType (ArgTy)) {
729
798
// the parameter is a SYCL sampler object
730
799
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl ();
@@ -747,6 +816,8 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
747
816
}
748
817
// structure or class typed parameter - the same handling as a scalar
749
818
CreateAndAddPrmDsc (Fld, ArgTy);
819
+ // create descriptors for each accessor field in the class or struct
820
+ createParamDescForWrappedAccessors (Fld, ArgTy);
750
821
} else if (ArgTy->isScalarType ()) {
751
822
// scalar typed parameter
752
823
CreateAndAddPrmDsc (Fld, ArgTy);
@@ -770,14 +841,7 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
770
841
const ASTRecordLayout &Layout = Ctx.getASTRecordLayout (KernelObjTy);
771
842
H.startKernel (Name, NameType);
772
843
773
- for (const auto Fld : KernelObjTy->fields ()) {
774
- QualType ActualArgType;
775
- QualType ArgTy = Fld->getType ();
776
-
777
- // Get offset in bytes
778
- uint64_t Offset = Layout.getFieldOffset (Fld->getFieldIndex ()) / 8 ;
779
-
780
- if (Util::isSyclAccessorType (ArgTy)) {
844
+ auto populateHeaderForAccessor = [&](const QualType &ArgTy, uint64_t Offset) {
781
845
// The parameter is a SYCL accessor object.
782
846
// The Info field of the parameter descriptor for accessor contains
783
847
// two template parameters packed into thid integer field:
@@ -790,6 +854,43 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
790
854
AccTmplTy->getTemplateArgs ()[1 ].getAsIntegral ().getExtValue ());
791
855
int Info = getAccessTarget (AccTmplTy) | (Dims << 11 );
792
856
H.addParamDesc (SYCLIntegrationHeader::kind_accessor, Info, Offset);
857
+ };
858
+
859
+ std::function<void (const QualType &, uint64_t Offset)>
860
+ populateHeaderForWrappedAccessors = [&](const QualType &ArgTy,
861
+ uint64_t Offset) {
862
+ const auto *Wrapper = ArgTy->getAsCXXRecordDecl ();
863
+ for (const auto *WrapperFld : Wrapper->fields ()) {
864
+ QualType FldType = WrapperFld->getType ();
865
+ if (FldType->isStructureOrClassType ()) {
866
+ ASTContext &WrapperCtx = Wrapper->getASTContext ();
867
+ const ASTRecordLayout &WrapperLayout =
868
+ WrapperCtx.getASTRecordLayout (Wrapper);
869
+ // Get offset (in bytes) of the field in wrapper class or struct
870
+ uint64_t OffsetInWrapper =
871
+ WrapperLayout.getFieldOffset (WrapperFld->getFieldIndex ()) / 8 ;
872
+ if (Util::isSyclAccessorType (FldType)) {
873
+ // This is an accesor - populate the header appropriately
874
+ populateHeaderForAccessor (FldType, Offset + OffsetInWrapper);
875
+ } else {
876
+ // This is an other class or struct - recursively search for an
877
+ // accessor field
878
+ populateHeaderForWrappedAccessors (FldType,
879
+ Offset + OffsetInWrapper);
880
+ }
881
+ }
882
+ }
883
+ };
884
+
885
+ for (const auto Fld : KernelObjTy->fields ()) {
886
+ QualType ActualArgType;
887
+ QualType ArgTy = Fld->getType ();
888
+
889
+ // Get offset in bytes
890
+ uint64_t Offset = Layout.getFieldOffset (Fld->getFieldIndex ()) / 8 ;
891
+
892
+ if (Util::isSyclAccessorType (ArgTy)) {
893
+ populateHeaderForAccessor (ArgTy, Offset);
793
894
} else if (Util::isSyclSamplerType (ArgTy)) {
794
895
// The parameter is a SYCL sampler object
795
896
// It has only one descriptor, "m_Sampler"
@@ -810,6 +911,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
810
911
uint64_t Sz = Ctx.getTypeSizeInChars (Fld->getType ()).getQuantity ();
811
912
H.addParamDesc (SYCLIntegrationHeader::kind_std_layout,
812
913
static_cast <unsigned >(Sz), static_cast <unsigned >(Offset));
914
+
915
+ // check for accessor fields in structure or class and populate the
916
+ // integration header appropriately
917
+ if (ArgTy->isStructureOrClassType ()) {
918
+ populateHeaderForWrappedAccessors (ArgTy, Offset);
919
+ }
813
920
} else {
814
921
llvm_unreachable (" unsupported kernel parameter type" );
815
922
}
0 commit comments