Skip to content

[SYCL] Support the cases when accessor is wrapped to some class #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 170 additions & 63 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
return Res;
};

QualType FieldType = Field->getType();
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl();
if (CRD && Util::isSyclAccessorType(FieldType)) {
auto getExprForAccessorInit = [&](const QualType &paramTy,
FieldDecl *Field,
const CXXRecordDecl *CRD, Expr *Base) {
// Since this is an accessor next 4 TargetFuncParams including current
// should be set in __init method: _ValueType*, range<int>, range<int>,
// id<int>
Expand All @@ -472,9 +472,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
std::advance(TargetFuncParam, NumParams - 1);

DeclAccessPair FieldDAP = DeclAccessPair::make(Field, AS_none);
// kernel_obj.accessor
// [kenrel_obj or wrapper object].accessor
auto AccessorME = MemberExpr::Create(
S.Context, CloneRef, false, SourceLocation(),
S.Context, Base, false, SourceLocation(),
NestedNameSpecifierLoc(), SourceLocation(), Field, FieldDAP,
DeclarationNameInfo(Field->getDeclName(), SourceLocation()),
nullptr, Field->getType(), VK_LValue, OK_Ordinary);
Expand All @@ -488,7 +488,7 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
}
assert(InitMethod && "The accessor must have the __init method");

// kernel_obj.accessor.__init
// [kenrel_obj or wrapper object].accessor.__init
DeclAccessPair MethodDAP = DeclAccessPair::make(InitMethod, AS_none);
auto ME = MemberExpr::Create(
S.Context, AccessorME, false, SourceLocation(),
Expand All @@ -515,11 +515,52 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
S, ((*ParamItr++))->getOriginalType(), ParamDREs[2]));
ParamStmts.push_back(getExprForRangeOrOffset(
S, ((*ParamItr++))->getOriginalType(), ParamDREs[3]));
// kernel_obj.accessor.__init(_ValueType*, range<int>, range<int>,
// id<int>)
// [kenrel_obj or wrapper object].accessor.__init(_ValueType*,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kenrel_obj ?

// range<int>, range<int>, id<int>)
CXXMemberCallExpr *Call = CXXMemberCallExpr::Create(
S.Context, ME, ParamStmts, ResultTy, VK, SourceLocation());
BodyStmts.push_back(Call);
};

// Recursively search for accessor fields to initialize them with kernel
// parameters
std::function<void(const CXXRecordDecl *, Expr *)>
getExprForWrappedAccessorInit = [&](const CXXRecordDecl *CRD,
Expr *Base) {
for (auto *WrapperFld : CRD->fields()) {
QualType FldType = WrapperFld->getType();
CXXRecordDecl *WrapperFldCRD = FldType->getAsCXXRecordDecl();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto instead of CXXRecordDecl since the type is obvious?

if (FldType->isStructureOrClassType()) {
if (Util::isSyclAccessorType(FldType)) {
// Accessor field found - create expr to initialize this
// accessor object. Need to start from the next target
// function parameter, since current one is the wrapper object
// or parameter of the previous processed accessor object.
TargetFuncParam++;
getExprForAccessorInit(FldType, WrapperFld, WrapperFldCRD,
Base);
} else {
// Field is a structure or class so change the wrapper object
// and recursively search for accessor field.
DeclAccessPair WrapperFieldDAP =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto instead of DeclAccessPair since the type is obvious?

DeclAccessPair::make(WrapperFld, AS_none);
auto NewBase = MemberExpr::Create(
S.Context, Base, false, SourceLocation(),
NestedNameSpecifierLoc(), SourceLocation(), WrapperFld,
WrapperFieldDAP,
DeclarationNameInfo(WrapperFld->getDeclName(),
SourceLocation()),
nullptr, WrapperFld->getType(), VK_LValue, OK_Ordinary);
getExprForWrappedAccessorInit(WrapperFldCRD, NewBase);
}
}
}
};

QualType FieldType = Field->getType();
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto instead of CXXRecordDecl since the type is obvious?

if (Util::isSyclAccessorType(FieldType)) {
getExprForAccessorInit(FieldType, Field, CRD, CloneRef);
} else if (CRD && Util::isSyclSamplerType(FieldType)) {

// Sampler has only one TargetFuncParam, which should be set in
Expand Down Expand Up @@ -596,6 +637,12 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
BinaryOperator(Lhs, Rhs, BO_Assign, FieldType, VK_LValue,
OK_Ordinary, SourceLocation(), FPOptions());
BodyStmts.push_back(Res);

// If a structure/class type has accessor fields then we need to
// initialize these accessors in proper way by calling __init method of
// the accessor and passing corresponding kernel parameters.
if (CRD)
getExprForWrappedAccessorInit(CRD, Lhs);
} else {
llvm_unreachable("unsupported field type");
}
Expand Down Expand Up @@ -675,56 +722,78 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
// create a parameter descriptor and append it to the result
ParamDescs.push_back(makeParamDesc(Fld, ArgType));
};

auto createAccessorParamDesc = [&](const FieldDecl *Fld,
const QualType &ArgTy) {
// the parameter is a SYCL accessor object
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl();
assert(RecordDecl && "accessor must be of a record type");
const auto *TemplateDecl =
cast<ClassTemplateSpecializationDecl>(RecordDecl);
// First accessor template parameter - data type
QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType();
// Fourth parameter - access target
target AccessTarget = getAccessTarget(TemplateDecl);
Qualifiers Quals = PointeeType.getQualifiers();
// TODO: Support all access targets
switch (AccessTarget) {
case target::global_buffer:
Quals.setAddressSpace(LangAS::opencl_global);
break;
case target::constant_buffer:
Quals.setAddressSpace(LangAS::opencl_constant);
break;
case target::local:
Quals.setAddressSpace(LangAS::opencl_local);
break;
default:
llvm_unreachable("Unsupported access target");
}
PointeeType =
Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals);
QualType PointerType = Context.getPointerType(PointeeType);

CreateAndAddPrmDsc(Fld, PointerType);

FieldDecl *AccessRangeFld =
getFieldDeclByName(RecordDecl, {"impl", "AccessRange"});
assert(AccessRangeFld &&
"The accessor.impl must contain the AccessRange field");
CreateAndAddPrmDsc(AccessRangeFld, AccessRangeFld->getType());

FieldDecl *MemRangeFld =
getFieldDeclByName(RecordDecl, {"impl", "MemRange"});
assert(MemRangeFld && "The accessor.impl must contain the MemRange field");
CreateAndAddPrmDsc(MemRangeFld, MemRangeFld->getType());

FieldDecl *OffsetFld = getFieldDeclByName(RecordDecl, {"impl", "Offset"});
assert(OffsetFld && "The accessor.impl must contain the Offset field");
CreateAndAddPrmDsc(OffsetFld, OffsetFld->getType());
};

std::function<void(const FieldDecl *, const QualType &ArgTy)>
createParamDescForWrappedAccessors =
[&](const FieldDecl *Fld, const QualType &ArgTy) {
const auto *Wrapper = ArgTy->getAsCXXRecordDecl();
for (const auto *WrapperFld : Wrapper->fields()) {
QualType FldType = WrapperFld->getType();
if (FldType->isStructureOrClassType()) {
if (Util::isSyclAccessorType(FldType)) {
// accessor field is found - create descriptor
createAccessorParamDesc(WrapperFld, FldType);
} else {
// field is some class or struct - recursively check for
// accessor fields
createParamDescForWrappedAccessors(WrapperFld, FldType);
}
}
}
};

for (const auto *Fld : KernelObj->fields()) {
QualType ArgTy = Fld->getType();
if (Util::isSyclAccessorType(ArgTy)) {
// the parameter is a SYCL accessor object
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl();
assert(RecordDecl && "accessor must be of a record type");
const auto *TemplateDecl =
cast<ClassTemplateSpecializationDecl>(RecordDecl);
// First accessor template parameter - data type
QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType();
// Fourth parameter - access target
target AccessTarget = getAccessTarget(TemplateDecl);
Qualifiers Quals = PointeeType.getQualifiers();
// TODO: Support all access targets
switch (AccessTarget) {
case target::global_buffer:
Quals.setAddressSpace(LangAS::opencl_global);
break;
case target::constant_buffer:
Quals.setAddressSpace(LangAS::opencl_constant);
break;
case target::local:
Quals.setAddressSpace(LangAS::opencl_local);
break;
default:
llvm_unreachable("Unsupported access target");
}
// TODO: get address space from accessor template parameter.
PointeeType =
Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals);
QualType PointerType = Context.getPointerType(PointeeType);

CreateAndAddPrmDsc(Fld, PointerType);

FieldDecl *AccessRangeFld =
getFieldDeclByName(RecordDecl, {"impl", "AccessRange"});
assert(AccessRangeFld &&
"The accessor.impl must contain the AccessRange field");
CreateAndAddPrmDsc(AccessRangeFld, AccessRangeFld->getType());

FieldDecl *MemRangeFld =
getFieldDeclByName(RecordDecl, {"impl", "MemRange"});
assert(MemRangeFld &&
"The accessor.impl must contain the MemRange field");
CreateAndAddPrmDsc(MemRangeFld, MemRangeFld->getType());

FieldDecl *OffsetFld =
getFieldDeclByName(RecordDecl, {"impl", "Offset"});
assert(OffsetFld && "The accessor.impl must contain the Offset field");
CreateAndAddPrmDsc(OffsetFld, OffsetFld->getType());
createAccessorParamDesc(Fld, ArgTy);
} else if (Util::isSyclSamplerType(ArgTy)) {
// the parameter is a SYCL sampler object
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl();
Expand All @@ -747,6 +816,8 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
}
// structure or class typed parameter - the same handling as a scalar
CreateAndAddPrmDsc(Fld, ArgTy);
// create descriptors for each accessor field in the class or struct
createParamDescForWrappedAccessors(Fld, ArgTy);
} else if (ArgTy->isScalarType()) {
// scalar typed parameter
CreateAndAddPrmDsc(Fld, ArgTy);
Expand All @@ -770,14 +841,7 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
const ASTRecordLayout &Layout = Ctx.getASTRecordLayout(KernelObjTy);
H.startKernel(Name, NameType);

for (const auto Fld : KernelObjTy->fields()) {
QualType ActualArgType;
QualType ArgTy = Fld->getType();

// Get offset in bytes
uint64_t Offset = Layout.getFieldOffset(Fld->getFieldIndex()) / 8;

if (Util::isSyclAccessorType(ArgTy)) {
auto populateHeaderForAccessor = [&](const QualType &ArgTy, uint64_t Offset) {
// The parameter is a SYCL accessor object.
// The Info field of the parameter descriptor for accessor contains
// two template parameters packed into thid integer field:
Expand All @@ -790,6 +854,43 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
AccTmplTy->getTemplateArgs()[1].getAsIntegral().getExtValue());
int Info = getAccessTarget(AccTmplTy) | (Dims << 11);
H.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info, Offset);
};

std::function<void(const QualType &, uint64_t Offset)>
populateHeaderForWrappedAccessors = [&](const QualType &ArgTy,
uint64_t Offset) {
const auto *Wrapper = ArgTy->getAsCXXRecordDecl();
for (const auto *WrapperFld : Wrapper->fields()) {
QualType FldType = WrapperFld->getType();
if (FldType->isStructureOrClassType()) {
ASTContext &WrapperCtx = Wrapper->getASTContext();
const ASTRecordLayout &WrapperLayout =
WrapperCtx.getASTRecordLayout(Wrapper);
// Get offset (in bytes) of the field in wrapper class or struct
uint64_t OffsetInWrapper =
WrapperLayout.getFieldOffset(WrapperFld->getFieldIndex()) / 8;
if (Util::isSyclAccessorType(FldType)) {
// This is an accesor - populate the header appropriately
populateHeaderForAccessor(FldType, Offset + OffsetInWrapper);
} else {
// This is an other class or struct - recursively search for an
// accessor field
populateHeaderForWrappedAccessors(FldType,
Offset + OffsetInWrapper);
}
}
}
};

for (const auto Fld : KernelObjTy->fields()) {
QualType ActualArgType;
QualType ArgTy = Fld->getType();

// Get offset in bytes
uint64_t Offset = Layout.getFieldOffset(Fld->getFieldIndex()) / 8;

if (Util::isSyclAccessorType(ArgTy)) {
populateHeaderForAccessor(ArgTy, Offset);
} else if (Util::isSyclSamplerType(ArgTy)) {
// The parameter is a SYCL sampler object
// It has only one descriptor, "m_Sampler"
Expand All @@ -810,6 +911,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
uint64_t Sz = Ctx.getTypeSizeInChars(Fld->getType()).getQuantity();
H.addParamDesc(SYCLIntegrationHeader::kind_std_layout,
static_cast<unsigned>(Sz), static_cast<unsigned>(Offset));

// check for accessor fields in structure or class and populate the
// integration header appropriately
if (ArgTy->isStructureOrClassType()) {
populateHeaderForWrappedAccessors(ArgTy, Offset);
}
} else {
llvm_unreachable("unsupported kernel parameter type");
}
Expand Down
51 changes: 51 additions & 0 deletions clang/test/CodeGenSYCL/wrapped-accessor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: %clang -I %S/Inputs --sycl -Xclang -fsycl-int-header=%t.h %s -c -o %T/kernel.spv
// RUN: FileCheck -input-file=%t.h %s
//
// CHECK: #include <CL/sycl/detail/kernel_desc.hpp>

// CHECK: class wrapped_access;

// CHECK: namespace cl {
// CHECK-NEXT: namespace sycl {
// CHECK-NEXT: namespace detail {

// CHECK: static constexpr
// CHECK-NEXT: const char* const kernel_names[] = {
// CHECK-NEXT: "_ZTSZ4mainE14wrapped_access"
// CHECK-NEXT: };

// CHECK: static constexpr
// CHECK-NEXT: const kernel_param_desc_t kernel_signatures[] = {
// CHECK-NEXT: //--- _ZTSZ4mainE14wrapped_access
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 3, 0 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 0 },
// CHECK-EMPTY:
// CHECK-NEXT: };

// CHECK: static constexpr
// CHECK-NEXT: const unsigned kernel_signature_start[] = {
// CHECK-NEXT: 0 // _ZTSZ4mainE14wrapped_access
// CHECK-NEXT: };

// CHECK: template <class KernelNameType> struct KernelInfo;

// CHECK: template <> struct KernelInfo<class wrapped_access> {

#include <sycl.hpp>

template <typename Acc>
struct AccWrapper { Acc accessor; };

template <typename name, typename Func>
__attribute__((sycl_kernel)) void kernel(Func kernelFunc) {
kernelFunc();
}

int main() {
cl::sycl::accessor<int, 1, cl::sycl::access::mode::read_write> acc;
auto acc_wrapped = AccWrapper<decltype(acc)>{acc};
kernel<class wrapped_access>(
[=]() {
acc_wrapped.accessor.use();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I appreciate this feature, I have the feeling it is a kind of magical SYCL extension...
To discuss inside the SYCL committee...

});
}
Loading