Skip to content

[WIP] Base classes handling in SemaSYCL #1861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 131 additions & 50 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,9 +683,6 @@ constructKernelName(Sema &S, FunctionDecl *KernelCallerFunc,
// anonymous namespace so these don't get linkage.
namespace {

QualType getItemType(const FieldDecl *FD) { return FD->getType(); }
QualType getItemType(const CXXBaseSpecifier &BS) { return BS.getType(); }

// These enable handler execution only when previous handlers succeed.
template <typename... Tn>
static bool handleField(FieldDecl *FD, QualType FDTy, Tn &&... tn) {
Expand Down Expand Up @@ -729,11 +726,6 @@ template <typename T> using bind_param_t = typename bind_param<T>::type;
// })...)

// Implements the 'for-each-visitor' pattern.
template <typename ParentTy, typename... Handlers>
static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent,
CXXRecordDecl *Wrapper,
Handlers &... handlers);

template <typename RangeTy, typename... Handlers>
static void VisitField(CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy,
Handlers &... handlers) {
Expand All @@ -742,7 +734,7 @@ static void VisitField(CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy,
if (Util::isSyclStreamType(ItemTy))
KF_FOR_EACH(handleSyclStreamType, Item, ItemTy);
if (ItemTy->isStructureOrClassType())
VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(),
VisitRecord(Owner, Item, ItemTy->getAsCXXRecordDecl(),
handlers...);
if (ItemTy->isArrayType())
VisitArrayElements(Item, ItemTy, handlers...);
Expand All @@ -762,38 +754,68 @@ static void VisitArrayElements(RangeTy Item, QualType FieldTy,
(void)std::initializer_list<int>{(handlers.leaveArray(ET, ElemCount), 0)...};
}

template <typename RangeTy, typename... Handlers>
static void VisitAccessorWrapperHelper(CXXRecordDecl *Owner, RangeTy Range,
Handlers &... handlers) {
for (const auto &Item : Range) {
QualType ItemTy = getItemType(Item);
(void)std::initializer_list<int>{(handlers.enterField(Owner, Item), 0)...};
VisitField(Owner, Item, ItemTy, handlers...);
(void)std::initializer_list<int>{(handlers.leaveField(Owner, Item), 0)...};
template <typename ParentTy, typename... Handlers>
static void VisitRecord(CXXRecordDecl *Owner, ParentTy &Parent,
CXXRecordDecl *Wrapper, Handlers &... handlers);

template <typename... Handlers>
static void VisitRecordHelper(CXXRecordDecl *Owner,
clang::CXXRecordDecl::base_class_range Range,
Handlers &... handlers) {
for (const auto &Base : Range) {
QualType BaseTy = Base.getType();
if (Util::isSyclAccessorType(BaseTy))
(void)std::initializer_list<int>{
(handlers.handleSyclAccessorType(Base, BaseTy), 0)...};
else if (Util::isSyclStreamType(BaseTy))
(void)std::initializer_list<int>{
(handlers.handleSyclStreamType(Base, BaseTy), 0)...};
else
VisitRecord(Owner, Base, BaseTy->getAsCXXRecordDecl(), handlers...);
}
}

template <typename... Handlers>
static void VisitRecordHelper(CXXRecordDecl *Owner,
clang::RecordDecl::field_range Range,
Handlers &... handlers) {
VisitRecordFields(Owner, handlers...);
}

// Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter
// the Wrapper structure that we're currently visiting. Owner is the parent
// type (which doesn't exist in cases where it is a FieldDecl in the
// 'root'), and Wrapper is the current struct being unwrapped.
template <typename ParentTy, typename... Handlers>
static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent,
CXXRecordDecl *Wrapper,
Handlers &... handlers) {
static void VisitRecord(CXXRecordDecl *Owner, ParentTy &Parent,
CXXRecordDecl *Wrapper, Handlers &... handlers) {
(void)std::initializer_list<int>{(handlers.enterStruct(Owner, Parent), 0)...};
VisitAccessorWrapperHelper(Wrapper, Wrapper->bases(), handlers...);
VisitAccessorWrapperHelper(Wrapper, Wrapper->fields(), handlers...);
VisitRecordHelper(Wrapper, Wrapper->bases(), handlers...);
VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...);
(void)std::initializer_list<int>{(handlers.leaveStruct(Owner, Parent), 0)...};
}

int getFieldNumber(const CXXRecordDecl *BaseDecl) {
int Members = 0;
for (const auto *Field : BaseDecl->fields())
++Members;

return Members;
}

template <typename... Handlers>
static void VisitFunctorBases(CXXRecordDecl *KernelFunctor,
Handlers &... handlers) {
VisitRecordHelper(KernelFunctor, KernelFunctor->bases(), handlers...);
}


// A visitor function that dispatches to functions as defined in
// SyclKernelFieldHandler for the purposes of kernel generation.
template <typename... Handlers>
static void VisitRecordFields(RecordDecl::field_range Fields,
Handlers &... handlers) {
static void VisitRecordFields(CXXRecordDecl *Owner, Handlers &... handlers) {

for (const auto Field : Fields) {
for (const auto Field : Owner->fields()) {
(void)std::initializer_list<int>{
(handlers.enterField(nullptr, Field), 0)...};
QualType FieldTy = Field->getType();
Expand All @@ -807,12 +829,12 @@ static void VisitRecordFields(RecordDecl::field_range Fields,
else if (Util::isSyclStreamType(FieldTy)) {
// Stream actually wraps accessors, so do recursion
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
VisitAccessorWrapper(nullptr, Field, RD, handlers...);
VisitRecord(nullptr, Field, RD, handlers...);
KF_FOR_EACH(handleSyclStreamType, Field, FieldTy);
} else if (FieldTy->isStructureOrClassType()) {
if (KF_FOR_EACH(handleStructType, Field, FieldTy)) {
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
VisitAccessorWrapper(nullptr, Field, RD, handlers...);
VisitRecord(nullptr, Field, RD, handlers...);
}
} else if (FieldTy->isReferenceType())
KF_FOR_EACH(handleReferenceType, Field, FieldTy);
Expand All @@ -821,7 +843,7 @@ static void VisitRecordFields(RecordDecl::field_range Fields,
else if (FieldTy->isArrayType()) {
if (KF_FOR_EACH(handleArrayType, Field, FieldTy))
VisitArrayElements(Field, FieldTy, handlers...);
} else if (FieldTy->isScalarType())
} else if (FieldTy->isScalarType() || FieldTy->isVectorType())
KF_FOR_EACH(handleScalarType, Field, FieldTy);
else
KF_FOR_EACH(handleOtherType, Field, FieldTy);
Expand Down Expand Up @@ -1131,7 +1153,7 @@ class SyclKernelDeclCreator
}

bool handleStructType(FieldDecl *FD, QualType FieldTy) final {
addParam(FD, FieldTy);
// addParam(FD, FieldTy);
return true;
}

Expand Down Expand Up @@ -1277,7 +1299,10 @@ class SyclKernelBodyCreator
VK_LValue, SourceLocation());
}

MemberExpr *SpecialObjME = BuildMemberExpr(Base, Field);
Expr *SpecialObjME = Base;
if (Field)
SpecialObjME = BuildMemberExpr(Base, Field);

MemberExpr *MethodME = BuildMemberExpr(SpecialObjME, Method);

QualType ResultTy = Method->getReturnType();
Expand Down Expand Up @@ -1312,22 +1337,39 @@ class SyclKernelBodyCreator

bool handleSpecialType(FieldDecl *FD, QualType Ty) {
const auto *RecordDecl = Ty->getAsCXXRecordDecl();
// Perform initialization only if it is field of kernel object
if (MemberExprBases.size() == 1) {
InitializedEntity Entity =
InitializedEntity::InitializeMember(FD, &VarEntity);
// Initialize with the default constructor.
InitializationKind InitKind =
InitializationKind::CreateDefault(SourceLocation());
InitializationSequence InitSeq(SemaRef, Entity, InitKind, None);
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None);
InitExprs.push_back(MemberInit.get());
}
// TODO: VarEntity is initialized entity for KernelObjClone, I guess we need
// to create new one when enter new struct.
InitializedEntity Entity =
InitializedEntity::InitializeMember(FD, &VarEntity);
// Initialize with the default constructor.
InitializationKind InitKind =
InitializationKind::CreateDefault(SourceLocation());
InitializationSequence InitSeq(SemaRef, Entity, InitKind, None);
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None);
InitExprs.push_back(MemberInit.get());
createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName,
FD);
return true;
}

bool handleSpecialType(const CXXBaseSpecifier &BS, QualType Ty) {
const auto *RecordDecl = Ty->getAsCXXRecordDecl();
// TODO: VarEntity is initialized entity for KernelObjClone, I guess we need
// to create new one when enter new struct.
InitializedEntity Entity = InitializedEntity::InitializeBase(
SemaRef.Context, &BS, /*IsInheritedVirtualBase*/ false, &VarEntity);
// Initialize with the default constructor.
InitializationKind InitKind =
InitializationKind::CreateDefault(SourceLocation());
InitializationSequence InitSeq(SemaRef, Entity, InitKind, None);
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None);
InitExprs.push_back(MemberInit.get());

createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName,
nullptr);
return true;
}

public:
SyclKernelBodyCreator(Sema &S, SyclKernelDeclCreator &DC,
CXXRecordDecl *KernelObj,
Expand Down Expand Up @@ -1359,9 +1401,7 @@ class SyclKernelBodyCreator
}

bool handleSyclAccessorType(const CXXBaseSpecifier &BS, QualType Ty) final {
// FIXME SYCL accessor should be usable as a base type
// See https://github.com/intel/llvm/issues/28.
return true;
return handleSpecialType(BS, Ty);
}

bool handleSyclSamplerType(FieldDecl *FD, QualType Ty) final {
Expand Down Expand Up @@ -1390,7 +1430,7 @@ class SyclKernelBodyCreator
}

bool handleStructType(FieldDecl *FD, QualType FieldTy) final {
createExprForStructOrScalar(FD);
// createExprForStructOrScalar(FD);
return true;
}

Expand All @@ -1403,12 +1443,51 @@ class SyclKernelBodyCreator
MemberExprBases.push_back(BuildMemberExpr(MemberExprBases.back(), FD));
}

void leaveStruct(const CXXRecordDecl *, FieldDecl *FD) final {
void enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
CXXCastPath BasePath;
QualType DerivedTy(RD->getTypeForDecl(), 0);
QualType BaseTy = BS.getType();
SemaRef.CheckDerivedToBaseConversion(DerivedTy, BaseTy, SourceLocation(),
SourceRange(), &BasePath,
/*IgnoreBaseAccess*/ true);
auto Cast = ImplicitCastExpr::Create(
SemaRef.Context, BaseTy, CK_DerivedToBase, MemberExprBases.back(),
/* CXXCastPath=*/&BasePath, VK_LValue);
MemberExprBases.push_back(Cast);
}

void addStructInit(const CXXRecordDecl *RD){
if (!RD)
return;

int NumberOfFields = getFieldNumber(RD);
int popOut = NumberOfFields + RD->getNumBases();
llvm::SmallVector<Expr *, 16> BaseInitExprs;
for (int I = 0; I < popOut; I++) {
BaseInitExprs.push_back(InitExprs.back());
InitExprs.pop_back();
}
std::reverse(BaseInitExprs.begin(), BaseInitExprs.end());

Expr *ILE = new (SemaRef.getASTContext())
InitListExpr(SemaRef.getASTContext(), SourceLocation(), BaseInitExprs,
SourceLocation());
ILE->setType(QualType(RD->getTypeForDecl(), 0));
InitExprs.push_back(ILE);

MemberExprBases.pop_back();
}

using SyclKernelFieldHandler::enterStruct;
using SyclKernelFieldHandler::leaveStruct;
void leaveStruct(const CXXRecordDecl *, FieldDecl *FD) final {
const CXXRecordDecl *RD = FD->getType()->getAsCXXRecordDecl();
addStructInit(RD);
}

void leaveStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
const CXXRecordDecl *BaseClass = BS.getType()->getAsCXXRecordDecl();
addStructInit(BaseClass);
}

};

class SyclKernelIntHeaderCreator
Expand Down Expand Up @@ -1512,7 +1591,7 @@ class SyclKernelIntHeaderCreator
return true;
}
bool handleStructType(FieldDecl *FD, QualType FieldTy) final {
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
// addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
return true;
}
bool handleScalarType(FieldDecl *FD, QualType FieldTy) final {
Expand Down Expand Up @@ -1606,7 +1685,9 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
StableName);

ConstructingOpenCLKernel = true;
VisitRecordFields(KernelLambda->fields(), checker, kernel_decl, kernel_body,
VisitFunctorBases(KernelLambda, checker, kernel_decl, kernel_body,
int_header);
VisitRecordFields(KernelLambda, checker, kernel_decl, kernel_body,
int_header);
ConstructingOpenCLKernel = false;
}
Expand Down
38 changes: 20 additions & 18 deletions clang/test/CodeGenSYCL/integration_header.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -fsycl-int-header=%t.h %s -fsyntax-only
// RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -fsycl-int-header=%t.h %s -emit-llvm
// RUN: FileCheck -input-file=%t.h %s
//
// CHECK: #include <CL/sycl/detail/kernel_desc.hpp>
Expand Down Expand Up @@ -28,9 +28,11 @@
// CHECK-NEXT: const kernel_param_desc_t kernel_signatures[] = {
// CHECK-NEXT: //--- _ZTSZ4mainE12first_kernel
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 4 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 6112, 16 },
// CHECK-NEXT: { kernel_param_kind_t::kind_sampler, 8, 32 },
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 1, 4 },
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 12 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 6112, 24 },
// CHECK-NEXT: { kernel_param_kind_t::kind_sampler, 8, 40 },
// CHECK-EMPTY:
// CHECK-NEXT: //--- _ZTSN16second_namespace13second_kernelIcEE
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 },
Expand All @@ -46,12 +48,15 @@
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 6112, 4 },
// CHECK-EMPTY:
// CHECK-NEXT: //--- _ZTSZ4mainE16accessor_in_base
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 64, 0 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 8 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 24 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 40 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 52 },
// CHECK-NEXT: //--- _ZTSZ4mainE16accessor_in_base
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 },
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 4 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 8 },
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 20 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 24 },
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 36 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 40 },
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 52 },
// CHECK-EMPTY:
// CHECK-NEXT: };
//
Expand Down Expand Up @@ -116,15 +121,13 @@ int main() {
acc2;
int i = 13;
cl::sycl::sampler smplr;
// TODO: Uncomemnt when structures in kernel arguments are correctly processed
// by SYCL compiler
/* struct {
struct {
char c;
int i;
} test_s;
test_s.c = 14;*/
test_s.c = 14;
kernel_single_task<class first_kernel>([=]() {
if (i == 13 /*&& test_s.c == 14*/) {
if (i == 13 && test_s.c == 14) {

acc1.use();
acc2.use();
Expand All @@ -151,10 +154,9 @@ int main() {
}
});

// FIXME: We cannot use the member-capture because all the handlers except the
// integration header handler in SemaSYCL don't handle base types right.
accessor_in_base::captured c;
kernel_single_task<class accessor_in_base>([c]() {
kernel_single_task<class accessor_in_base>([=]() {
c.use();
});

return 0;
Expand Down