-
Notifications
You must be signed in to change notification settings - Fork 788
[SYCL] Support the cases when accessor is wrapped to some class #82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -454,9 +454,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { | |
return Res; | ||
}; | ||
|
||
QualType FieldType = Field->getType(); | ||
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl(); | ||
if (CRD && Util::isSyclAccessorType(FieldType)) { | ||
auto getExprForAccessorInit = [&](const QualType ¶mTy, | ||
FieldDecl *Field, | ||
const CXXRecordDecl *CRD, Expr *Base) { | ||
// Since this is an accessor next 4 TargetFuncParams including current | ||
// should be set in __init method: _ValueType*, range<int>, range<int>, | ||
// id<int> | ||
|
@@ -472,9 +472,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { | |
std::advance(TargetFuncParam, NumParams - 1); | ||
|
||
DeclAccessPair FieldDAP = DeclAccessPair::make(Field, AS_none); | ||
// kernel_obj.accessor | ||
// [kenrel_obj or wrapper object].accessor | ||
auto AccessorME = MemberExpr::Create( | ||
S.Context, CloneRef, false, SourceLocation(), | ||
S.Context, Base, false, SourceLocation(), | ||
NestedNameSpecifierLoc(), SourceLocation(), Field, FieldDAP, | ||
DeclarationNameInfo(Field->getDeclName(), SourceLocation()), | ||
nullptr, Field->getType(), VK_LValue, OK_Ordinary); | ||
|
@@ -488,7 +488,7 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { | |
} | ||
assert(InitMethod && "The accessor must have the __init method"); | ||
|
||
// kernel_obj.accessor.__init | ||
// [kenrel_obj or wrapper object].accessor.__init | ||
DeclAccessPair MethodDAP = DeclAccessPair::make(InitMethod, AS_none); | ||
auto ME = MemberExpr::Create( | ||
S.Context, AccessorME, false, SourceLocation(), | ||
|
@@ -515,11 +515,52 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { | |
S, ((*ParamItr++))->getOriginalType(), ParamDREs[2])); | ||
ParamStmts.push_back(getExprForRangeOrOffset( | ||
S, ((*ParamItr++))->getOriginalType(), ParamDREs[3])); | ||
// kernel_obj.accessor.__init(_ValueType*, range<int>, range<int>, | ||
// id<int>) | ||
// [kenrel_obj or wrapper object].accessor.__init(_ValueType*, | ||
// range<int>, range<int>, id<int>) | ||
CXXMemberCallExpr *Call = CXXMemberCallExpr::Create( | ||
S.Context, ME, ParamStmts, ResultTy, VK, SourceLocation()); | ||
BodyStmts.push_back(Call); | ||
}; | ||
|
||
// Recursively search for accessor fields to initialize them with kernel | ||
// parameters | ||
std::function<void(const CXXRecordDecl *, Expr *)> | ||
getExprForWrappedAccessorInit = [&](const CXXRecordDecl *CRD, | ||
Expr *Base) { | ||
for (auto *WrapperFld : CRD->fields()) { | ||
QualType FldType = WrapperFld->getType(); | ||
CXXRecordDecl *WrapperFldCRD = FldType->getAsCXXRecordDecl(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if (FldType->isStructureOrClassType()) { | ||
if (Util::isSyclAccessorType(FldType)) { | ||
// Accessor field found - create expr to initialize this | ||
// accessor object. Need to start from the next target | ||
// function parameter, since current one is the wrapper object | ||
// or parameter of the previous processed accessor object. | ||
TargetFuncParam++; | ||
getExprForAccessorInit(FldType, WrapperFld, WrapperFldCRD, | ||
Base); | ||
} else { | ||
// Field is a structure or class so change the wrapper object | ||
// and recursively search for accessor field. | ||
DeclAccessPair WrapperFieldDAP = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
DeclAccessPair::make(WrapperFld, AS_none); | ||
auto NewBase = MemberExpr::Create( | ||
S.Context, Base, false, SourceLocation(), | ||
NestedNameSpecifierLoc(), SourceLocation(), WrapperFld, | ||
WrapperFieldDAP, | ||
DeclarationNameInfo(WrapperFld->getDeclName(), | ||
SourceLocation()), | ||
nullptr, WrapperFld->getType(), VK_LValue, OK_Ordinary); | ||
getExprForWrappedAccessorInit(WrapperFldCRD, NewBase); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
QualType FieldType = Field->getType(); | ||
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if (Util::isSyclAccessorType(FieldType)) { | ||
getExprForAccessorInit(FieldType, Field, CRD, CloneRef); | ||
} else if (CRD && Util::isSyclSamplerType(FieldType)) { | ||
|
||
// Sampler has only one TargetFuncParam, which should be set in | ||
|
@@ -596,6 +637,12 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) { | |
BinaryOperator(Lhs, Rhs, BO_Assign, FieldType, VK_LValue, | ||
OK_Ordinary, SourceLocation(), FPOptions()); | ||
BodyStmts.push_back(Res); | ||
|
||
// If a structure/class type has accessor fields then we need to | ||
// initialize these accessors in proper way by calling __init method of | ||
// the accessor and passing corresponding kernel parameters. | ||
if (CRD) | ||
getExprForWrappedAccessorInit(CRD, Lhs); | ||
} else { | ||
llvm_unreachable("unsupported field type"); | ||
} | ||
|
@@ -675,56 +722,78 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj, | |
// create a parameter descriptor and append it to the result | ||
ParamDescs.push_back(makeParamDesc(Fld, ArgType)); | ||
}; | ||
|
||
auto createAccessorParamDesc = [&](const FieldDecl *Fld, | ||
const QualType &ArgTy) { | ||
// the parameter is a SYCL accessor object | ||
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl(); | ||
assert(RecordDecl && "accessor must be of a record type"); | ||
const auto *TemplateDecl = | ||
cast<ClassTemplateSpecializationDecl>(RecordDecl); | ||
// First accessor template parameter - data type | ||
QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType(); | ||
// Fourth parameter - access target | ||
target AccessTarget = getAccessTarget(TemplateDecl); | ||
Qualifiers Quals = PointeeType.getQualifiers(); | ||
// TODO: Support all access targets | ||
switch (AccessTarget) { | ||
case target::global_buffer: | ||
Quals.setAddressSpace(LangAS::opencl_global); | ||
break; | ||
case target::constant_buffer: | ||
Quals.setAddressSpace(LangAS::opencl_constant); | ||
break; | ||
case target::local: | ||
Quals.setAddressSpace(LangAS::opencl_local); | ||
break; | ||
default: | ||
llvm_unreachable("Unsupported access target"); | ||
} | ||
PointeeType = | ||
Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals); | ||
QualType PointerType = Context.getPointerType(PointeeType); | ||
|
||
CreateAndAddPrmDsc(Fld, PointerType); | ||
|
||
FieldDecl *AccessRangeFld = | ||
getFieldDeclByName(RecordDecl, {"impl", "AccessRange"}); | ||
assert(AccessRangeFld && | ||
"The accessor.impl must contain the AccessRange field"); | ||
CreateAndAddPrmDsc(AccessRangeFld, AccessRangeFld->getType()); | ||
|
||
FieldDecl *MemRangeFld = | ||
getFieldDeclByName(RecordDecl, {"impl", "MemRange"}); | ||
assert(MemRangeFld && "The accessor.impl must contain the MemRange field"); | ||
CreateAndAddPrmDsc(MemRangeFld, MemRangeFld->getType()); | ||
|
||
FieldDecl *OffsetFld = getFieldDeclByName(RecordDecl, {"impl", "Offset"}); | ||
assert(OffsetFld && "The accessor.impl must contain the Offset field"); | ||
CreateAndAddPrmDsc(OffsetFld, OffsetFld->getType()); | ||
}; | ||
|
||
std::function<void(const FieldDecl *, const QualType &ArgTy)> | ||
createParamDescForWrappedAccessors = | ||
[&](const FieldDecl *Fld, const QualType &ArgTy) { | ||
const auto *Wrapper = ArgTy->getAsCXXRecordDecl(); | ||
for (const auto *WrapperFld : Wrapper->fields()) { | ||
QualType FldType = WrapperFld->getType(); | ||
if (FldType->isStructureOrClassType()) { | ||
if (Util::isSyclAccessorType(FldType)) { | ||
// accessor field is found - create descriptor | ||
createAccessorParamDesc(WrapperFld, FldType); | ||
} else { | ||
// field is some class or struct - recursively check for | ||
// accessor fields | ||
createParamDescForWrappedAccessors(WrapperFld, FldType); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
for (const auto *Fld : KernelObj->fields()) { | ||
QualType ArgTy = Fld->getType(); | ||
if (Util::isSyclAccessorType(ArgTy)) { | ||
// the parameter is a SYCL accessor object | ||
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl(); | ||
assert(RecordDecl && "accessor must be of a record type"); | ||
const auto *TemplateDecl = | ||
cast<ClassTemplateSpecializationDecl>(RecordDecl); | ||
// First accessor template parameter - data type | ||
QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType(); | ||
// Fourth parameter - access target | ||
target AccessTarget = getAccessTarget(TemplateDecl); | ||
Qualifiers Quals = PointeeType.getQualifiers(); | ||
// TODO: Support all access targets | ||
switch (AccessTarget) { | ||
case target::global_buffer: | ||
Quals.setAddressSpace(LangAS::opencl_global); | ||
break; | ||
case target::constant_buffer: | ||
Quals.setAddressSpace(LangAS::opencl_constant); | ||
break; | ||
case target::local: | ||
Quals.setAddressSpace(LangAS::opencl_local); | ||
break; | ||
default: | ||
llvm_unreachable("Unsupported access target"); | ||
} | ||
// TODO: get address space from accessor template parameter. | ||
PointeeType = | ||
Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals); | ||
QualType PointerType = Context.getPointerType(PointeeType); | ||
|
||
CreateAndAddPrmDsc(Fld, PointerType); | ||
|
||
FieldDecl *AccessRangeFld = | ||
getFieldDeclByName(RecordDecl, {"impl", "AccessRange"}); | ||
assert(AccessRangeFld && | ||
"The accessor.impl must contain the AccessRange field"); | ||
CreateAndAddPrmDsc(AccessRangeFld, AccessRangeFld->getType()); | ||
|
||
FieldDecl *MemRangeFld = | ||
getFieldDeclByName(RecordDecl, {"impl", "MemRange"}); | ||
assert(MemRangeFld && | ||
"The accessor.impl must contain the MemRange field"); | ||
CreateAndAddPrmDsc(MemRangeFld, MemRangeFld->getType()); | ||
|
||
FieldDecl *OffsetFld = | ||
getFieldDeclByName(RecordDecl, {"impl", "Offset"}); | ||
assert(OffsetFld && "The accessor.impl must contain the Offset field"); | ||
CreateAndAddPrmDsc(OffsetFld, OffsetFld->getType()); | ||
createAccessorParamDesc(Fld, ArgTy); | ||
} else if (Util::isSyclSamplerType(ArgTy)) { | ||
// the parameter is a SYCL sampler object | ||
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl(); | ||
|
@@ -747,6 +816,8 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj, | |
} | ||
// structure or class typed parameter - the same handling as a scalar | ||
CreateAndAddPrmDsc(Fld, ArgTy); | ||
// create descriptors for each accessor field in the class or struct | ||
createParamDescForWrappedAccessors(Fld, ArgTy); | ||
} else if (ArgTy->isScalarType()) { | ||
// scalar typed parameter | ||
CreateAndAddPrmDsc(Fld, ArgTy); | ||
|
@@ -770,14 +841,7 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name, | |
const ASTRecordLayout &Layout = Ctx.getASTRecordLayout(KernelObjTy); | ||
H.startKernel(Name, NameType); | ||
|
||
for (const auto Fld : KernelObjTy->fields()) { | ||
QualType ActualArgType; | ||
QualType ArgTy = Fld->getType(); | ||
|
||
// Get offset in bytes | ||
uint64_t Offset = Layout.getFieldOffset(Fld->getFieldIndex()) / 8; | ||
|
||
if (Util::isSyclAccessorType(ArgTy)) { | ||
auto populateHeaderForAccessor = [&](const QualType &ArgTy, uint64_t Offset) { | ||
// The parameter is a SYCL accessor object. | ||
// The Info field of the parameter descriptor for accessor contains | ||
// two template parameters packed into thid integer field: | ||
|
@@ -790,6 +854,43 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name, | |
AccTmplTy->getTemplateArgs()[1].getAsIntegral().getExtValue()); | ||
int Info = getAccessTarget(AccTmplTy) | (Dims << 11); | ||
H.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info, Offset); | ||
}; | ||
|
||
std::function<void(const QualType &, uint64_t Offset)> | ||
populateHeaderForWrappedAccessors = [&](const QualType &ArgTy, | ||
uint64_t Offset) { | ||
const auto *Wrapper = ArgTy->getAsCXXRecordDecl(); | ||
for (const auto *WrapperFld : Wrapper->fields()) { | ||
QualType FldType = WrapperFld->getType(); | ||
if (FldType->isStructureOrClassType()) { | ||
ASTContext &WrapperCtx = Wrapper->getASTContext(); | ||
const ASTRecordLayout &WrapperLayout = | ||
WrapperCtx.getASTRecordLayout(Wrapper); | ||
// Get offset (in bytes) of the field in wrapper class or struct | ||
uint64_t OffsetInWrapper = | ||
WrapperLayout.getFieldOffset(WrapperFld->getFieldIndex()) / 8; | ||
if (Util::isSyclAccessorType(FldType)) { | ||
// This is an accesor - populate the header appropriately | ||
populateHeaderForAccessor(FldType, Offset + OffsetInWrapper); | ||
} else { | ||
// This is an other class or struct - recursively search for an | ||
// accessor field | ||
populateHeaderForWrappedAccessors(FldType, | ||
Offset + OffsetInWrapper); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
for (const auto Fld : KernelObjTy->fields()) { | ||
QualType ActualArgType; | ||
QualType ArgTy = Fld->getType(); | ||
|
||
// Get offset in bytes | ||
uint64_t Offset = Layout.getFieldOffset(Fld->getFieldIndex()) / 8; | ||
|
||
if (Util::isSyclAccessorType(ArgTy)) { | ||
populateHeaderForAccessor(ArgTy, Offset); | ||
} else if (Util::isSyclSamplerType(ArgTy)) { | ||
// The parameter is a SYCL sampler object | ||
// It has only one descriptor, "m_Sampler" | ||
|
@@ -810,6 +911,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name, | |
uint64_t Sz = Ctx.getTypeSizeInChars(Fld->getType()).getQuantity(); | ||
H.addParamDesc(SYCLIntegrationHeader::kind_std_layout, | ||
static_cast<unsigned>(Sz), static_cast<unsigned>(Offset)); | ||
|
||
// check for accessor fields in structure or class and populate the | ||
// integration header appropriately | ||
if (ArgTy->isStructureOrClassType()) { | ||
populateHeaderForWrappedAccessors(ArgTy, Offset); | ||
} | ||
} else { | ||
llvm_unreachable("unsupported kernel parameter type"); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// RUN: %clang -I %S/Inputs --sycl -Xclang -fsycl-int-header=%t.h %s -c -o %T/kernel.spv | ||
// RUN: FileCheck -input-file=%t.h %s | ||
// | ||
// CHECK: #include <CL/sycl/detail/kernel_desc.hpp> | ||
|
||
// CHECK: class wrapped_access; | ||
|
||
// CHECK: namespace cl { | ||
// CHECK-NEXT: namespace sycl { | ||
// CHECK-NEXT: namespace detail { | ||
|
||
// CHECK: static constexpr | ||
// CHECK-NEXT: const char* const kernel_names[] = { | ||
// CHECK-NEXT: "_ZTSZ4mainE14wrapped_access" | ||
// CHECK-NEXT: }; | ||
|
||
// CHECK: static constexpr | ||
// CHECK-NEXT: const kernel_param_desc_t kernel_signatures[] = { | ||
// CHECK-NEXT: //--- _ZTSZ4mainE14wrapped_access | ||
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 3, 0 }, | ||
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 0 }, | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: }; | ||
|
||
// CHECK: static constexpr | ||
// CHECK-NEXT: const unsigned kernel_signature_start[] = { | ||
// CHECK-NEXT: 0 // _ZTSZ4mainE14wrapped_access | ||
// CHECK-NEXT: }; | ||
|
||
// CHECK: template <class KernelNameType> struct KernelInfo; | ||
|
||
// CHECK: template <> struct KernelInfo<class wrapped_access> { | ||
|
||
#include <sycl.hpp> | ||
|
||
template <typename Acc> | ||
struct AccWrapper { Acc accessor; }; | ||
|
||
template <typename name, typename Func> | ||
__attribute__((sycl_kernel)) void kernel(Func kernelFunc) { | ||
kernelFunc(); | ||
} | ||
|
||
int main() { | ||
cl::sycl::accessor<int, 1, cl::sycl::access::mode::read_write> acc; | ||
auto acc_wrapped = AccWrapper<decltype(acc)>{acc}; | ||
kernel<class wrapped_access>( | ||
[=]() { | ||
acc_wrapped.accessor.use(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While I appreciate this feature, I have the feeling it is a kind of magical SYCL extension... |
||
}); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kenrel_obj
?