Skip to content

Commit 33ae7ec

Browse files
committed
[SYCL] Do not decompose non-trivial classes with pointers
Instead, the following code is generated:
```
void ocl_kernel(__generated_type GT) {
  Kernel KernelObjClone { *(reinterpret_cast<UsersType*>(&GT)) };
}
```
1 parent 3916d3b commit 33ae7ec

File tree

9 files changed

+113
-190
lines changed

9 files changed

+113
-190
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,20 +1796,10 @@ class SyclKernelDecompMarker : public SyclKernelFieldHandler {
17961796
CollectionStack.back() = true;
17971797
PointerStack.pop_back();
17981798
} else if (PointerStack.pop_back_val()) {
1799-
// FIXME: Stop triggering decomposition for non-trivial types with
1800-
// pointers
1801-
if (RD->isTrivial()) {
1802-
PointerStack.back() = true;
1803-
if (!RD->hasAttr<SYCLGenerateNewTypeAttr>())
1804-
RD->addAttr(
1805-
SYCLGenerateNewTypeAttr::CreateImplicit(SemaRef.getASTContext()));
1806-
} else {
1807-
// We are visiting a non-trivial type with pointer.
1808-
CollectionStack.back() = true;
1809-
if (!RD->hasAttr<SYCLRequiresDecompositionAttr>())
1810-
RD->addAttr(SYCLRequiresDecompositionAttr::CreateImplicit(
1811-
SemaRef.getASTContext()));
1812-
}
1799+
PointerStack.back() = true;
1800+
if (!RD->hasAttr<SYCLGenerateNewTypeAttr>())
1801+
RD->addAttr(
1802+
SYCLGenerateNewTypeAttr::CreateImplicit(SemaRef.getASTContext()));
18131803
}
18141804
return true;
18151805
}
@@ -2916,6 +2906,18 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
29162906
Init.get());
29172907
}
29182908

2909+
void addBaseInit(const CXXBaseSpecifier &BS, QualType Ty,
2910+
InitializationKind InitKind, MultiExprArg Args) {
2911+
InitializedEntity Entity = InitializedEntity::InitializeBase(
2912+
SemaRef.Context, &BS, /*IsInheritedVirtualBase*/ false, &VarEntity);
2913+
InitializationSequence InitSeq(SemaRef, Entity, InitKind, Args);
2914+
ExprResult Init = InitSeq.Perform(SemaRef, Entity, InitKind, Args);
2915+
2916+
InitListExpr *ParentILE = CollectionInitExprs.back();
2917+
ParentILE->updateInit(SemaRef.getASTContext(), ParentILE->getNumInits(),
2918+
Init.get());
2919+
}
2920+
29192921
void addSimpleBaseInit(const CXXBaseSpecifier &BS, QualType Ty) {
29202922
InitializationKind InitKind =
29212923
InitializationKind::CreateCopy(KernelCallerSrcLoc, KernelCallerSrcLoc);
@@ -2961,6 +2963,13 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
29612963
false, SemaRef.CurFPFeatureOverrides());
29622964
}
29632965

2966+
Expr *createDerefOp(Expr *E) {
2967+
return UnaryOperator::Create(SemaRef.Context, E, UO_Deref,
2968+
E->getType()->getPointeeType(),
2969+
VK_LValue, OK_Ordinary, KernelCallerSrcLoc,
2970+
false, SemaRef.CurFPFeatureOverrides());
2971+
}
2972+
29642973
Expr *buildMemCpyCall(Expr *From, Expr *To, QualType T) {
29652974
// Compute the size of the memory buffer to be copied.
29662975
QualType SizeType = SemaRef.Context.getSizeType();
@@ -2992,31 +3001,40 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
29923001
return Call.getAs<Expr>();
29933002
}
29943003

2995-
// Adds default initializer for generated type and creates
2996-
// a call to __builtin_memcpy to initialize local clone from
2997-
// kernel argument.
29983004
void handleGeneratedType(FieldDecl *FD, QualType Ty) {
2999-
addFieldInit(FD, Ty, None,
3000-
InitializationKind::CreateDefault(KernelCallerSrcLoc));
3001-
addFieldMemberExpr(FD, Ty);
3002-
Expr *ParamRef = createGetAddressOf(createParamReferenceExpr());
3003-
Expr *LocalCloneRef = createGetAddressOf(MemberExprBases.back());
3004-
Expr *MemCpyCallExpr = buildMemCpyCall(ParamRef, LocalCloneRef, Ty);
3005-
BodyStmts.push_back(MemCpyCallExpr);
3006-
removeFieldMemberExpr(FD, Ty);
3005+
// Equivalent of the following code is generated here:
3006+
// void ocl_kernel(__generated_type GT) {
3007+
// Kernel KernelObjClone { *(reinterpret_cast<UsersType*>(&GT)) };
3008+
// }
3009+
3010+
Expr *ParamRef = createParamReferenceExpr();
3011+
Expr *ParamAddress = createGetAddressOf(ParamRef);
3012+
3013+
QualType ResultType = SemaRef.Context.getPointerType(Ty);
3014+
TypeSourceInfo *TSI = SemaRef.Context.CreateTypeSourceInfo(ResultType);
3015+
CXXReinterpretCastExpr *RCE = CXXReinterpretCastExpr::Create(
3016+
SemaRef.Context, ResultType, VK_PRValue, CK_BitCast, ParamAddress,
3017+
/*Path=*/nullptr, TSI, SourceLocation(), SourceLocation(),
3018+
SourceRange());
3019+
Expr *Initializer = createDerefOp(RCE);
3020+
addFieldInit(FD, Ty, Initializer);
30073021
}
30083022

3009-
// Adds default initializer for generated base and creates
3010-
// a call to __builtin_memcpy to initialize the base of local clone
3011-
// from kernel argument.
30123023
void handleGeneratedType(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS,
30133024
QualType Ty) {
3014-
addBaseInit(BS, Ty, InitializationKind::CreateDefault(KernelCallerSrcLoc));
3015-
Expr *ParamRef = createGetAddressOf(createParamReferenceExpr());
3016-
Expr *LocalCloneRef = createGetAddressOf(MemberExprBases.back());
3017-
LocalCloneRef = addDerivedToBaseCastExpr(RD, BS, LocalCloneRef);
3018-
Expr *MemCpyCallExpr = buildMemCpyCall(ParamRef, LocalCloneRef, Ty);
3019-
BodyStmts.push_back(MemCpyCallExpr);
3025+
Expr *ParamRef = createParamReferenceExpr();
3026+
Expr *ParamAddress = createGetAddressOf(ParamRef);
3027+
3028+
QualType ResultType = SemaRef.Context.getPointerType(Ty);
3029+
TypeSourceInfo *TSI = SemaRef.Context.CreateTypeSourceInfo(ResultType);
3030+
CXXReinterpretCastExpr *RCE = CXXReinterpretCastExpr::Create(
3031+
SemaRef.Context, ResultType, VK_PRValue, CK_BitCast, ParamAddress,
3032+
/*Path=*/nullptr, TSI, SourceLocation(), SourceLocation(),
3033+
SourceRange());
3034+
Expr *Initializer = createDerefOp(RCE);
3035+
InitializationKind InitKind =
3036+
InitializationKind::CreateCopy(KernelCallerSrcLoc, KernelCallerSrcLoc);
3037+
addBaseInit(BS, Ty, InitKind, Initializer);
30203038
}
30213039

30223040
MemberExpr *buildMemberExpr(Expr *Base, ValueDecl *Member) {

clang/test/CodeGenSYCL/inheritance.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,14 @@ int main() {
6464
// Initialize 'base' subobject
6565
// CHECK: call void @llvm.memcpy.p4.p4.i64(ptr addrspace(4) align 8 %[[LOCAL_OBJECT]], ptr addrspace(4) align 4 %[[ARG_BASE]], i64 12, i1 false)
6666

67-
// Initialize field 'a'
68-
// CHECK: %[[GEP_A:[a-zA-Z0-9]+]] = getelementptr inbounds %struct.derived, ptr addrspace(4) %[[LOCAL_OBJECT]], i32 0, i32 3
69-
// CHECK: %[[LOAD_A:[0-9]+]] = load i32, ptr addrspace(4) %[[ARG_A]], align 4
70-
// CHECK: store i32 %[[LOAD_A]], ptr addrspace(4) %[[GEP_A]]
71-
7267
// Initialize 'second_base' subobject
7368
// First, derived-to-base cast with offset:
7469
// CHECK: %[[OFFSET_CALC:.*]] = getelementptr inbounds i8, ptr addrspace(4) %[[LOCAL_OBJECT]], i64 16
7570
// Initialize 'second_base'
7671
// CHECK: call void @llvm.memcpy.p4.p4.i64(ptr addrspace(4) align 8 %[[OFFSET_CALC]], ptr addrspace(4) align 8 %[[ARG_BASE1]], i64 8, i1 false)
72+
73+
// Initialize field 'a'
74+
// CHECK: %[[GEP_A:[a-zA-Z0-9]+]] = getelementptr inbounds %struct.derived, ptr addrspace(4) %[[LOCAL_OBJECT]], i32 0, i32 3
75+
// CHECK: %[[LOAD_A:[0-9]+]] = load i32, ptr addrspace(4) %[[ARG_A]], align 4
76+
// CHECK: store i32 %[[LOAD_A]], ptr addrspace(4) %[[GEP_A]]
77+

clang/test/CodeGenSYCL/no_opaque_inheritance.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,18 @@ int main() {
6767
// CHECK: %[[PARAM_TO_PTR:.*]] = bitcast %struct.base addrspace(4)* %[[ARG_BASE]] to i8 addrspace(4)*
6868
// CHECK: call void @llvm.memcpy.p4i8.p4i8.i64(i8 addrspace(4)* align 8 %[[BASE_TO_PTR]], i8 addrspace(4)* align 4 %[[PARAM_TO_PTR]], i64 12, i1 false)
6969

70-
// Initialize field 'a'
71-
// CHECK: %[[GEP_A:[a-zA-Z0-9]+]] = getelementptr inbounds %struct.derived, %struct.derived addrspace(4)* %[[LOCAL_OBJECT]], i32 0, i32 3
72-
// CHECK: %[[LOAD_A:[0-9]+]] = load i32, i32 addrspace(4)* %[[ARG_A]], align 4
73-
// CHECK: store i32 %[[LOAD_A]], i32 addrspace(4)* %[[GEP_A]]
74-
7570
// Initialize 'second_base' subobject
7671
// First, derived-to-base cast with offset:
7772
// CHECK: %[[DERIVED_PTR:.*]] = bitcast %struct.derived addrspace(4)* %[[LOCAL_OBJECT]] to i8 addrspace(4)*
7873
// CHECK: %[[OFFSET_CALC:.*]] = getelementptr inbounds i8, i8 addrspace(4)* %[[DERIVED_PTR]], i64 16
7974
// CHECK: %[[TO_SECOND_BASE:.*]] = bitcast i8 addrspace(4)* %[[OFFSET_CALC]] to %class.second_base addrspace(4)*
80-
// CHECK: %[[SECOND_BASE_TO_PTR:.*]] = bitcast %class.second_base addrspace(4)* %[[TO_SECOND_BASE]] to i8 addrspace(4)*
81-
// CHECK: %[[SECOND_PARAM_TO_PTR:.*]] = bitcast %class.__generated_second_base addrspace(4)* %[[ARG_BASE1]] to i8 addrspace(4)*
82-
// CHECK: call void @llvm.memcpy.p4i8.p4i8.i64(i8 addrspace(4)* align 8 %[[SECOND_BASE_TO_PTR]], i8 addrspace(4)* align 8 %[[SECOND_PARAM_TO_PTR]], i64 8, i1 false)
75+
// CHECK: %[[GEN_TO_SECOND_BASE:.*]] = bitcast %class.__generated_second_base addrspace(4)* %[[ARG_BASE1]] to %class.second_base addrspace(4)*
76+
// CHECK: %[[TO:.*]] = bitcast %class.second_base addrspace(4)* %[[TO_SECOND_BASE]] to i8 addrspace(4)*
77+
// CHECK: %[[FROM:.*]] = bitcast %class.second_base addrspace(4)* %[[GEN_TO_SECOND_BASE]] to i8 addrspace(4)*
78+
// CHECK: call void @llvm.memcpy.p4i8.p4i8.i64(i8 addrspace(4)* align 8 %[[TO]], i8 addrspace(4)* align 8 %[[FROM]], i64 8, i1 false)
8379

80+
81+
// Initialize field 'a'
82+
// CHECK: %[[GEP_A:[a-zA-Z0-9]+]] = getelementptr inbounds %struct.derived, %struct.derived addrspace(4)* %[[LOCAL_OBJECT]], i32 0, i32 3
83+
// CHECK: %[[LOAD_A:[0-9]+]] = load i32, i32 addrspace(4)* %[[ARG_A]], align 4
84+
// CHECK: store i32 %[[LOAD_A]], i32 addrspace(4)* %[[GEP_A]]

clang/test/CodeGenSYCL/no_opaque_pointers-in-structs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@ int main() {
4545
// CHECK-SAME: %[[GENERATED_A]]* noundef byval(%[[GENERATED_A]]) align 8 %_arg_F3,
4646
// CHECK-SAME: %[[WRAPPER_F4_1]]* noundef byval(%[[WRAPPER_F4_1]]) align 8 %_arg_F4
4747
// CHECK-SAME: %[[WRAPPER_F4_2]]* noundef byval(%[[WRAPPER_F4_2]]) align 8 %_arg_F41
48-
// CHECK: define {{.*}}spir_kernel void @{{.*}}lambdas{{.*}}(%[[WRAPPER_LAMBDA_PTR]]* noundef byval(%[[WRAPPER_LAMBDA_PTR]]) align 8 %_arg_Ptr)
48+
// CHECK: define {{.*}}spir_kernel void @{{.*}}lambdas{{.*}}(%[[WRAPPER_LAMBDA_PTR]]* noundef byval(%[[WRAPPER_LAMBDA_PTR]]) align 8 %_arg_Lambda)

clang/test/CodeGenSYCL/pointers-in-structs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@ int main() {
4545
// CHECK-SAME: ptr noundef byval(%[[GENERATED_A]]) align 8 %_arg_F3,
4646
// CHECK-SAME: ptr noundef byval(%[[WRAPPER_F4_1]]) align 8 %_arg_F4
4747
// CHECK-SAME: ptr noundef byval(%[[WRAPPER_F4_2]]) align 8 %_arg_F41
48-
// CHECK: define {{.*}}spir_kernel void @{{.*}}lambdas{{.*}}(ptr noundef byval(%[[WRAPPER_LAMBDA_PTR]]) align 8 %_arg_Ptr)
48+
// CHECK: define {{.*}}spir_kernel void @{{.*}}lambdas{{.*}}(ptr noundef byval(%[[WRAPPER_LAMBDA_PTR]]) align 8 %_arg_Lambda)

0 commit comments

Comments
 (0)