Skip to content

Commit 223adac

Browse files
againullromanovvlad
authored andcommitted
[SYCL] Support the cases when accessor is wrapped to some class
Current implementation supports only the cases when captured accessor is a top level object. But accessors could be wrapped to some classes. Frontend should properly handle this case and intialize accessors using appropriate kernel parameters. Signed-off-by: Artur Gainullin <[email protected]>
1 parent fc9bcc5 commit 223adac

File tree

4 files changed

+391
-63
lines changed

4 files changed

+391
-63
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 170 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
454454
return Res;
455455
};
456456

457-
QualType FieldType = Field->getType();
458-
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl();
459-
if (CRD && Util::isSyclAccessorType(FieldType)) {
457+
auto getExprForAccessorInit = [&](const QualType &paramTy,
458+
FieldDecl *Field,
459+
const CXXRecordDecl *CRD, Expr *Base) {
460460
// Since this is an accessor next 4 TargetFuncParams including current
461461
// should be set in __init method: _ValueType*, range<int>, range<int>,
462462
// id<int>
@@ -472,9 +472,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
472472
std::advance(TargetFuncParam, NumParams - 1);
473473

474474
DeclAccessPair FieldDAP = DeclAccessPair::make(Field, AS_none);
475-
// kernel_obj.accessor
475+
// [kenrel_obj or wrapper object].accessor
476476
auto AccessorME = MemberExpr::Create(
477-
S.Context, CloneRef, false, SourceLocation(),
477+
S.Context, Base, false, SourceLocation(),
478478
NestedNameSpecifierLoc(), SourceLocation(), Field, FieldDAP,
479479
DeclarationNameInfo(Field->getDeclName(), SourceLocation()),
480480
nullptr, Field->getType(), VK_LValue, OK_Ordinary);
@@ -488,7 +488,7 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
488488
}
489489
assert(InitMethod && "The accessor must have the __init method");
490490

491-
// kernel_obj.accessor.__init
491+
// [kenrel_obj or wrapper object].accessor.__init
492492
DeclAccessPair MethodDAP = DeclAccessPair::make(InitMethod, AS_none);
493493
auto ME = MemberExpr::Create(
494494
S.Context, AccessorME, false, SourceLocation(),
@@ -515,11 +515,52 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
515515
S, ((*ParamItr++))->getOriginalType(), ParamDREs[2]));
516516
ParamStmts.push_back(getExprForRangeOrOffset(
517517
S, ((*ParamItr++))->getOriginalType(), ParamDREs[3]));
518-
// kernel_obj.accessor.__init(_ValueType*, range<int>, range<int>,
519-
// id<int>)
518+
// [kenrel_obj or wrapper object].accessor.__init(_ValueType*,
519+
// range<int>, range<int>, id<int>)
520520
CXXMemberCallExpr *Call = CXXMemberCallExpr::Create(
521521
S.Context, ME, ParamStmts, ResultTy, VK, SourceLocation());
522522
BodyStmts.push_back(Call);
523+
};
524+
525+
// Recursively search for accessor fields to initialize them with kernel
526+
// parameters
527+
std::function<void(const CXXRecordDecl *, Expr *)>
528+
getExprForWrappedAccessorInit = [&](const CXXRecordDecl *CRD,
529+
Expr *Base) {
530+
for (auto *WrapperFld : CRD->fields()) {
531+
QualType FldType = WrapperFld->getType();
532+
CXXRecordDecl *WrapperFldCRD = FldType->getAsCXXRecordDecl();
533+
if (FldType->isStructureOrClassType()) {
534+
if (Util::isSyclAccessorType(FldType)) {
535+
// Accessor field found - create expr to initialize this
536+
// accessor object. Need to start from the next target
537+
// function parameter, since current one is the wrapper object
538+
// or parameter of the previous processed accessor object.
539+
TargetFuncParam++;
540+
getExprForAccessorInit(FldType, WrapperFld, WrapperFldCRD,
541+
Base);
542+
} else {
543+
// Field is a structure or class so change the wrapper object
544+
// and recursively search for accessor field.
545+
DeclAccessPair WrapperFieldDAP =
546+
DeclAccessPair::make(WrapperFld, AS_none);
547+
auto NewBase = MemberExpr::Create(
548+
S.Context, Base, false, SourceLocation(),
549+
NestedNameSpecifierLoc(), SourceLocation(), WrapperFld,
550+
WrapperFieldDAP,
551+
DeclarationNameInfo(WrapperFld->getDeclName(),
552+
SourceLocation()),
553+
nullptr, WrapperFld->getType(), VK_LValue, OK_Ordinary);
554+
getExprForWrappedAccessorInit(WrapperFldCRD, NewBase);
555+
}
556+
}
557+
}
558+
};
559+
560+
QualType FieldType = Field->getType();
561+
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl();
562+
if (Util::isSyclAccessorType(FieldType)) {
563+
getExprForAccessorInit(FieldType, Field, CRD, CloneRef);
523564
} else if (CRD && Util::isSyclSamplerType(FieldType)) {
524565

525566
// Sampler has only one TargetFuncParam, which should be set in
@@ -596,6 +637,12 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
596637
BinaryOperator(Lhs, Rhs, BO_Assign, FieldType, VK_LValue,
597638
OK_Ordinary, SourceLocation(), FPOptions());
598639
BodyStmts.push_back(Res);
640+
641+
// If a structure/class type has accessor fields then we need to
642+
// initialize these accessors in proper way by calling __init method of
643+
// the accessor and passing corresponding kernel parameters.
644+
if (CRD)
645+
getExprForWrappedAccessorInit(CRD, Lhs);
599646
} else {
600647
llvm_unreachable("unsupported field type");
601648
}
@@ -675,56 +722,78 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
675722
// create a parameter descriptor and append it to the result
676723
ParamDescs.push_back(makeParamDesc(Fld, ArgType));
677724
};
725+
726+
auto createAccessorParamDesc = [&](const FieldDecl *Fld,
727+
const QualType &ArgTy) {
728+
// the parameter is a SYCL accessor object
729+
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl();
730+
assert(RecordDecl && "accessor must be of a record type");
731+
const auto *TemplateDecl =
732+
cast<ClassTemplateSpecializationDecl>(RecordDecl);
733+
// First accessor template parameter - data type
734+
QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType();
735+
// Fourth parameter - access target
736+
target AccessTarget = getAccessTarget(TemplateDecl);
737+
Qualifiers Quals = PointeeType.getQualifiers();
738+
// TODO: Support all access targets
739+
switch (AccessTarget) {
740+
case target::global_buffer:
741+
Quals.setAddressSpace(LangAS::opencl_global);
742+
break;
743+
case target::constant_buffer:
744+
Quals.setAddressSpace(LangAS::opencl_constant);
745+
break;
746+
case target::local:
747+
Quals.setAddressSpace(LangAS::opencl_local);
748+
break;
749+
default:
750+
llvm_unreachable("Unsupported access target");
751+
}
752+
PointeeType =
753+
Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals);
754+
QualType PointerType = Context.getPointerType(PointeeType);
755+
756+
CreateAndAddPrmDsc(Fld, PointerType);
757+
758+
FieldDecl *AccessRangeFld =
759+
getFieldDeclByName(RecordDecl, {"impl", "AccessRange"});
760+
assert(AccessRangeFld &&
761+
"The accessor.impl must contain the AccessRange field");
762+
CreateAndAddPrmDsc(AccessRangeFld, AccessRangeFld->getType());
763+
764+
FieldDecl *MemRangeFld =
765+
getFieldDeclByName(RecordDecl, {"impl", "MemRange"});
766+
assert(MemRangeFld && "The accessor.impl must contain the MemRange field");
767+
CreateAndAddPrmDsc(MemRangeFld, MemRangeFld->getType());
768+
769+
FieldDecl *OffsetFld = getFieldDeclByName(RecordDecl, {"impl", "Offset"});
770+
assert(OffsetFld && "The accessor.impl must contain the Offset field");
771+
CreateAndAddPrmDsc(OffsetFld, OffsetFld->getType());
772+
};
773+
774+
std::function<void(const FieldDecl *, const QualType &ArgTy)>
775+
createParamDescForWrappedAccessors =
776+
[&](const FieldDecl *Fld, const QualType &ArgTy) {
777+
const auto *Wrapper = ArgTy->getAsCXXRecordDecl();
778+
for (const auto *WrapperFld : Wrapper->fields()) {
779+
QualType FldType = WrapperFld->getType();
780+
if (FldType->isStructureOrClassType()) {
781+
if (Util::isSyclAccessorType(FldType)) {
782+
// accessor field is found - create descriptor
783+
createAccessorParamDesc(WrapperFld, FldType);
784+
} else {
785+
// field is some class or struct - recursively check for
786+
// accessor fields
787+
createParamDescForWrappedAccessors(WrapperFld, FldType);
788+
}
789+
}
790+
}
791+
};
792+
678793
for (const auto *Fld : KernelObj->fields()) {
679794
QualType ArgTy = Fld->getType();
680795
if (Util::isSyclAccessorType(ArgTy)) {
681-
// the parameter is a SYCL accessor object
682-
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl();
683-
assert(RecordDecl && "accessor must be of a record type");
684-
const auto *TemplateDecl =
685-
cast<ClassTemplateSpecializationDecl>(RecordDecl);
686-
// First accessor template parameter - data type
687-
QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType();
688-
// Fourth parameter - access target
689-
target AccessTarget = getAccessTarget(TemplateDecl);
690-
Qualifiers Quals = PointeeType.getQualifiers();
691-
// TODO: Support all access targets
692-
switch (AccessTarget) {
693-
case target::global_buffer:
694-
Quals.setAddressSpace(LangAS::opencl_global);
695-
break;
696-
case target::constant_buffer:
697-
Quals.setAddressSpace(LangAS::opencl_constant);
698-
break;
699-
case target::local:
700-
Quals.setAddressSpace(LangAS::opencl_local);
701-
break;
702-
default:
703-
llvm_unreachable("Unsupported access target");
704-
}
705-
// TODO: get address space from accessor template parameter.
706-
PointeeType =
707-
Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals);
708-
QualType PointerType = Context.getPointerType(PointeeType);
709-
710-
CreateAndAddPrmDsc(Fld, PointerType);
711-
712-
FieldDecl *AccessRangeFld =
713-
getFieldDeclByName(RecordDecl, {"impl", "AccessRange"});
714-
assert(AccessRangeFld &&
715-
"The accessor.impl must contain the AccessRange field");
716-
CreateAndAddPrmDsc(AccessRangeFld, AccessRangeFld->getType());
717-
718-
FieldDecl *MemRangeFld =
719-
getFieldDeclByName(RecordDecl, {"impl", "MemRange"});
720-
assert(MemRangeFld &&
721-
"The accessor.impl must contain the MemRange field");
722-
CreateAndAddPrmDsc(MemRangeFld, MemRangeFld->getType());
723-
724-
FieldDecl *OffsetFld =
725-
getFieldDeclByName(RecordDecl, {"impl", "Offset"});
726-
assert(OffsetFld && "The accessor.impl must contain the Offset field");
727-
CreateAndAddPrmDsc(OffsetFld, OffsetFld->getType());
796+
createAccessorParamDesc(Fld, ArgTy);
728797
} else if (Util::isSyclSamplerType(ArgTy)) {
729798
// the parameter is a SYCL sampler object
730799
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl();
@@ -747,6 +816,8 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
747816
}
748817
// structure or class typed parameter - the same handling as a scalar
749818
CreateAndAddPrmDsc(Fld, ArgTy);
819+
// create descriptors for each accessor field in the class or struct
820+
createParamDescForWrappedAccessors(Fld, ArgTy);
750821
} else if (ArgTy->isScalarType()) {
751822
// scalar typed parameter
752823
CreateAndAddPrmDsc(Fld, ArgTy);
@@ -770,14 +841,7 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
770841
const ASTRecordLayout &Layout = Ctx.getASTRecordLayout(KernelObjTy);
771842
H.startKernel(Name, NameType);
772843

773-
for (const auto Fld : KernelObjTy->fields()) {
774-
QualType ActualArgType;
775-
QualType ArgTy = Fld->getType();
776-
777-
// Get offset in bytes
778-
uint64_t Offset = Layout.getFieldOffset(Fld->getFieldIndex()) / 8;
779-
780-
if (Util::isSyclAccessorType(ArgTy)) {
844+
auto populateHeaderForAccessor = [&](const QualType &ArgTy, uint64_t Offset) {
781845
// The parameter is a SYCL accessor object.
782846
// The Info field of the parameter descriptor for accessor contains
783847
// two template parameters packed into thid integer field:
@@ -790,6 +854,43 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
790854
AccTmplTy->getTemplateArgs()[1].getAsIntegral().getExtValue());
791855
int Info = getAccessTarget(AccTmplTy) | (Dims << 11);
792856
H.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info, Offset);
857+
};
858+
859+
std::function<void(const QualType &, uint64_t Offset)>
860+
populateHeaderForWrappedAccessors = [&](const QualType &ArgTy,
861+
uint64_t Offset) {
862+
const auto *Wrapper = ArgTy->getAsCXXRecordDecl();
863+
for (const auto *WrapperFld : Wrapper->fields()) {
864+
QualType FldType = WrapperFld->getType();
865+
if (FldType->isStructureOrClassType()) {
866+
ASTContext &WrapperCtx = Wrapper->getASTContext();
867+
const ASTRecordLayout &WrapperLayout =
868+
WrapperCtx.getASTRecordLayout(Wrapper);
869+
// Get offset (in bytes) of the field in wrapper class or struct
870+
uint64_t OffsetInWrapper =
871+
WrapperLayout.getFieldOffset(WrapperFld->getFieldIndex()) / 8;
872+
if (Util::isSyclAccessorType(FldType)) {
873+
// This is an accesor - populate the header appropriately
874+
populateHeaderForAccessor(FldType, Offset + OffsetInWrapper);
875+
} else {
876+
// This is an other class or struct - recursively search for an
877+
// accessor field
878+
populateHeaderForWrappedAccessors(FldType,
879+
Offset + OffsetInWrapper);
880+
}
881+
}
882+
}
883+
};
884+
885+
for (const auto Fld : KernelObjTy->fields()) {
886+
QualType ActualArgType;
887+
QualType ArgTy = Fld->getType();
888+
889+
// Get offset in bytes
890+
uint64_t Offset = Layout.getFieldOffset(Fld->getFieldIndex()) / 8;
891+
892+
if (Util::isSyclAccessorType(ArgTy)) {
893+
populateHeaderForAccessor(ArgTy, Offset);
793894
} else if (Util::isSyclSamplerType(ArgTy)) {
794895
// The parameter is a SYCL sampler object
795896
// It has only one descriptor, "m_Sampler"
@@ -810,6 +911,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
810911
uint64_t Sz = Ctx.getTypeSizeInChars(Fld->getType()).getQuantity();
811912
H.addParamDesc(SYCLIntegrationHeader::kind_std_layout,
812913
static_cast<unsigned>(Sz), static_cast<unsigned>(Offset));
914+
915+
// check for accessor fields in structure or class and populate the
916+
// integration header appropriately
917+
if (ArgTy->isStructureOrClassType()) {
918+
populateHeaderForWrappedAccessors(ArgTy, Offset);
919+
}
813920
} else {
814921
llvm_unreachable("unsupported kernel parameter type");
815922
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: %clang -I %S/Inputs --sycl -Xclang -fsycl-int-header=%t.h %s -c -o %T/kernel.spv
2+
// RUN: FileCheck -input-file=%t.h %s
3+
//
4+
// CHECK: #include <CL/sycl/detail/kernel_desc.hpp>
5+
6+
// CHECK: class wrapped_access;
7+
8+
// CHECK: namespace cl {
9+
// CHECK-NEXT: namespace sycl {
10+
// CHECK-NEXT: namespace detail {
11+
12+
// CHECK: static constexpr
13+
// CHECK-NEXT: const char* const kernel_names[] = {
14+
// CHECK-NEXT: "_ZTSZ4mainE14wrapped_access"
15+
// CHECK-NEXT: };
16+
17+
// CHECK: static constexpr
18+
// CHECK-NEXT: const kernel_param_desc_t kernel_signatures[] = {
19+
// CHECK-NEXT: //--- _ZTSZ4mainE14wrapped_access
20+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 3, 0 },
21+
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 0 },
22+
// CHECK-EMPTY:
23+
// CHECK-NEXT: };
24+
25+
// CHECK: static constexpr
26+
// CHECK-NEXT: const unsigned kernel_signature_start[] = {
27+
// CHECK-NEXT: 0 // _ZTSZ4mainE14wrapped_access
28+
// CHECK-NEXT: };
29+
30+
// CHECK: template <class KernelNameType> struct KernelInfo;
31+
32+
// CHECK: template <> struct KernelInfo<class wrapped_access> {
33+
34+
#include <sycl.hpp>
35+
36+
template <typename Acc>
37+
struct AccWrapper { Acc accessor; };
38+
39+
template <typename name, typename Func>
40+
__attribute__((sycl_kernel)) void kernel(Func kernelFunc) {
41+
kernelFunc();
42+
}
43+
44+
int main() {
45+
cl::sycl::accessor<int, 1, cl::sycl::access::mode::read_write> acc;
46+
auto acc_wrapped = AccWrapper<decltype(acc)>{acc};
47+
kernel<class wrapped_access>(
48+
[=]() {
49+
acc_wrapped.accessor.use();
50+
});
51+
}

0 commit comments

Comments
 (0)