Skip to content

Commit 5adfd79

Browse files
authored
[SYCL] Add support for union types as kernel parameter (#2285)
This patch adds support for union types kernel arguments. - Add 'VisitUnion' function. - Add separate handler to implement a diagnostic on attempt to pass union with accessor/sampler/stream member as a kernel argument. - Add CodeGen/Sema/Intengreation header/runtime tests. Signed-off-by: Soumi Manna <[email protected]>
1 parent 63ac3d3 commit 5adfd79

File tree

9 files changed

+491
-6
lines changed

9 files changed

+491
-6
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9736,6 +9736,8 @@ def warn_opencl_generic_address_space_arg : Warning<
97369736
"passing non-generic address space pointer to %0"
97379737
" may cause dynamic conversion affecting performance">,
97389738
InGroup<Conversion>, DefaultIgnore;
9739+
def err_bad_union_kernel_param_members : Error<
9740+
"%0 cannot be used inside a union kernel parameter">;
97399741

97409742
// OpenCL v2.0 s6.13.6 -- Builtin Pipe Functions
97419743
def err_opencl_builtin_pipe_first_arg : Error<

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 147 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,9 @@ class KernelObjVisitor {
831831
else if (ElementTy->isStructureOrClassType())
832832
VisitRecord(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
833833
handlers...);
834+
else if (ElementTy->isUnionType())
835+
VisitUnion(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
836+
handlers...);
834837
else if (ElementTy->isArrayType())
835838
VisitArrayElements(ArrayField, ElementTy, handlers...);
836839
else if (ElementTy->isScalarType())
@@ -858,6 +861,65 @@ class KernelObjVisitor {
858861
void VisitRecord(const CXXRecordDecl *Owner, ParentTy &Parent,
859862
const CXXRecordDecl *Wrapper, Handlers &... handlers);
860863

864+
// Base case, only calls these when filtered.
865+
template <typename... FilteredHandlers, typename ParentTy>
866+
std::enable_if_t<(sizeof...(FilteredHandlers) > 0)>
867+
VisitUnion(FilteredHandlers &... handlers, const CXXRecordDecl *Owner,
868+
ParentTy &Parent, const CXXRecordDecl *Wrapper) {
869+
(void)std::initializer_list<int>{
870+
(handlers.enterUnion(Owner, Parent), 0)...};
871+
VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...);
872+
(void)std::initializer_list<int>{
873+
(handlers.leaveUnion(Owner, Parent), 0)...};
874+
}
875+
876+
// Handle empty base case.
877+
template <typename ParentTy>
878+
void VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
879+
const CXXRecordDecl *Wrapper) {}
880+
881+
template <typename... FilteredHandlers, typename ParentTy,
882+
typename CurHandler, typename... Handlers>
883+
std::enable_if_t<!CurHandler::VisitUnionBody &&
884+
(sizeof...(FilteredHandlers) > 0)>
885+
VisitUnion(FilteredHandlers &... filtered_handlers,
886+
const CXXRecordDecl *Owner, ParentTy &Parent,
887+
const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
888+
Handlers &... handlers) {
889+
VisitUnion<FilteredHandlers...>(filtered_handlers..., Owner, Parent,
890+
Wrapper, handlers...);
891+
}
892+
893+
template <typename... FilteredHandlers, typename ParentTy,
894+
typename CurHandler, typename... Handlers>
895+
std::enable_if_t<CurHandler::VisitUnionBody &&
896+
(sizeof...(FilteredHandlers) > 0)>
897+
VisitUnion(FilteredHandlers &... filtered_handlers,
898+
const CXXRecordDecl *Owner, ParentTy &Parent,
899+
const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
900+
Handlers &... handlers) {
901+
VisitUnion<FilteredHandlers..., CurHandler>(
902+
filtered_handlers..., cur_handler, Owner, Parent, Wrapper, handlers...);
903+
}
904+
905+
// Add overloads without having filtered-handlers
906+
// to handle leading-empty argument packs.
907+
template <typename ParentTy, typename CurHandler, typename... Handlers>
908+
std::enable_if_t<!CurHandler::VisitUnionBody>
909+
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
910+
const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
911+
Handlers &... handlers) {
912+
VisitUnion(Owner, Parent, Wrapper, handlers...);
913+
}
914+
915+
template <typename ParentTy, typename CurHandler, typename... Handlers>
916+
std::enable_if_t<CurHandler::VisitUnionBody>
917+
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
918+
const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
919+
Handlers &... handlers) {
920+
VisitUnion<CurHandler>(cur_handler, Owner, Parent, Wrapper, handlers...);
921+
}
922+
861923
template <typename... Handlers>
862924
void VisitRecordHelper(const CXXRecordDecl *Owner,
863925
clang::CXXRecordDecl::base_class_const_range Range,
@@ -943,6 +1005,11 @@ class KernelObjVisitor {
9431005
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
9441006
VisitRecord(Owner, Field, RD, handlers...);
9451007
}
1008+
} else if (FieldTy->isUnionType()) {
1009+
if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
1010+
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
1011+
VisitUnion(Owner, Field, RD, handlers...);
1012+
}
9461013
} else if (FieldTy->isReferenceType())
9471014
KF_FOR_EACH(handleReferenceType, Field, FieldTy);
9481015
else if (FieldTy->isPointerType())
@@ -982,6 +1049,7 @@ class SyclKernelFieldHandler {
9821049
SyclKernelFieldHandler(Sema &S) : SemaRef(S) {}
9831050

9841051
public:
1052+
static constexpr const bool VisitUnionBody = false;
9851053
// Mark these virtual so that we can use override in the implementer classes,
9861054
// despite virtual dispatch never being used.
9871055

@@ -1006,6 +1074,7 @@ class SyclKernelFieldHandler {
10061074
}
10071075
virtual bool handleSyclHalfType(FieldDecl *, QualType) { return true; }
10081076
virtual bool handleStructType(FieldDecl *, QualType) { return true; }
1077+
virtual bool handleUnionType(FieldDecl *, QualType) { return true; }
10091078
virtual bool handleReferenceType(FieldDecl *, QualType) { return true; }
10101079
virtual bool handlePointerType(FieldDecl *, QualType) { return true; }
10111080
virtual bool handleArrayType(FieldDecl *, QualType) { return true; }
@@ -1025,6 +1094,8 @@ class SyclKernelFieldHandler {
10251094
virtual bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) {
10261095
return true;
10271096
}
1097+
virtual bool enterUnion(const CXXRecordDecl *, FieldDecl *) { return true; }
1098+
virtual bool leaveUnion(const CXXRecordDecl *, FieldDecl *) { return true; }
10281099

10291100
// The following are used for stepping through array elements.
10301101

@@ -1047,7 +1118,6 @@ class SyclKernelFieldHandler {
10471118
class SyclKernelFieldChecker : public SyclKernelFieldHandler {
10481119
bool IsInvalid = false;
10491120
DiagnosticsEngine &Diag;
1050-
10511121
// Check whether the object should be disallowed from being copied to kernel.
10521122
// Return true if not copyable, false if copyable.
10531123
bool checkNotCopyableToKernel(const FieldDecl *FD, const QualType &FieldTy) {
@@ -1202,6 +1272,65 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
12021272
}
12031273
};
12041274

1275+
// A type to check the validity of accessing accessor/sampler/stream
1276+
// types as kernel parameters inside union.
1277+
class SyclKernelUnionChecker : public SyclKernelFieldHandler {
1278+
int UnionCount = 0;
1279+
bool IsInvalid = false;
1280+
DiagnosticsEngine &Diag;
1281+
1282+
public:
1283+
SyclKernelUnionChecker(Sema &S)
1284+
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
1285+
bool isValid() { return !IsInvalid; }
1286+
static constexpr const bool VisitUnionBody = true;
1287+
1288+
bool checkType(SourceLocation Loc, QualType Ty) {
1289+
if (UnionCount) {
1290+
IsInvalid = true;
1291+
Diag.Report(Loc, diag::err_bad_union_kernel_param_members) << Ty;
1292+
}
1293+
return isValid();
1294+
}
1295+
1296+
bool enterUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
1297+
++UnionCount;
1298+
return true;
1299+
}
1300+
1301+
bool leaveUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
1302+
--UnionCount;
1303+
return true;
1304+
}
1305+
1306+
bool handleSyclAccessorType(FieldDecl *FD, QualType FieldTy) final {
1307+
return checkType(FD->getLocation(), FieldTy);
1308+
}
1309+
1310+
bool handleSyclAccessorType(const CXXBaseSpecifier &BS,
1311+
QualType FieldTy) final {
1312+
return checkType(BS.getBeginLoc(), FieldTy);
1313+
}
1314+
1315+
bool handleSyclSamplerType(FieldDecl *FD, QualType FieldTy) final {
1316+
return checkType(FD->getLocation(), FieldTy);
1317+
}
1318+
1319+
bool handleSyclSamplerType(const CXXBaseSpecifier &BS,
1320+
QualType FieldTy) final {
1321+
return checkType(BS.getBeginLoc(), FieldTy);
1322+
}
1323+
1324+
bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final {
1325+
return checkType(FD->getLocation(), FieldTy);
1326+
}
1327+
1328+
bool handleSyclStreamType(const CXXBaseSpecifier &BS,
1329+
QualType FieldTy) final {
1330+
return checkType(BS.getBeginLoc(), FieldTy);
1331+
}
1332+
};
1333+
12051334
// A type to Create and own the FunctionDecl for the kernel.
12061335
class SyclKernelDeclCreator : public SyclKernelFieldHandler {
12071336
FunctionDecl *KernelDecl;
@@ -1453,6 +1582,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
14531582
return true;
14541583
}
14551584

1585+
bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
1586+
return handleScalarType(FD, FieldTy);
1587+
}
1588+
14561589
bool handleSyclHalfType(FieldDecl *FD, QualType FieldTy) final {
14571590
addParam(FD, FieldTy);
14581591
return true;
@@ -1805,6 +1938,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
18051938
return true;
18061939
}
18071940

1941+
bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
1942+
return handleScalarType(FD, FieldTy);
1943+
}
1944+
18081945
bool enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
18091946
CXXCastPath BasePath;
18101947
QualType DerivedTy(RD->getTypeForDecl(), 0);
@@ -2012,6 +2149,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
20122149
return true;
20132150
}
20142151

2152+
bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
2153+
return handleScalarType(FD, FieldTy);
2154+
}
2155+
20152156
bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final {
20162157
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
20172158
return true;
@@ -2105,14 +2246,14 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
21052246
}
21062247
}
21072248

2108-
SyclKernelFieldChecker Checker(*this);
2109-
2249+
SyclKernelFieldChecker FieldChecker(*this);
2250+
SyclKernelUnionChecker UnionChecker(*this);
21102251
KernelObjVisitor Visitor{*this};
21112252
DiagnosingSYCLKernel = true;
2112-
Visitor.VisitRecordBases(KernelObj, Checker);
2113-
Visitor.VisitRecordFields(KernelObj, Checker);
2253+
Visitor.VisitRecordBases(KernelObj, FieldChecker, UnionChecker);
2254+
Visitor.VisitRecordFields(KernelObj, FieldChecker, UnionChecker);
21142255
DiagnosingSYCLKernel = false;
2115-
if (!Checker.isValid())
2256+
if (!FieldChecker.isValid() || !UnionChecker.isValid())
21162257
KernelFunc->setInvalidDecl();
21172258
}
21182259

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -fsycl-int-header=%t.h %s -o %t.out
2+
// RUN: FileCheck -input-file=%t.h %s
3+
4+
// This test checks the integration header generated when
5+
// the kernel argument is union.
6+
7+
// CHECK: #include <CL/sycl/detail/kernel_desc.hpp>
8+
9+
// CHECK: class kernel_A;
10+
11+
// CHECK: __SYCL_INLINE_NAMESPACE(cl) {
12+
// CHECK-NEXT: namespace sycl {
13+
// CHECK-NEXT: namespace detail {
14+
15+
// CHECK: static constexpr
16+
// CHECK-NEXT: const char* const kernel_names[] = {
17+
// CHECK-NEXT: "_ZTSZ4mainE8kernel_A"
18+
// CHECK-NEXT: };
19+
20+
// CHECK: static constexpr
21+
// CHECK-NEXT: const kernel_param_desc_t kernel_signatures[] = {
22+
// CHECK-NEXT: //--- _ZTSZ4mainE8kernel_A
23+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 12, 0 },
24+
// CHECK-EMPTY:
25+
// CHECK-NEXT:};
26+
27+
// CHECK: static constexpr
28+
// CHECK-NEXT: const unsigned kernel_signature_start[] = {
29+
// CHECK-NEXT: 0 // _ZTSZ4mainE8kernel_A
30+
// CHECK-NEXT: };
31+
32+
// CHECK: template <> struct KernelInfo<class kernel_A> {
33+
34+
union MyUnion {
35+
int FldInt;
36+
char FldChar;
37+
float FldArr[3];
38+
};
39+
40+
template <typename name, typename Func>
41+
__attribute__((sycl_kernel)) void a_kernel(Func kernelFunc) {
42+
kernelFunc();
43+
}
44+
45+
int main() {
46+
47+
MyUnion obj;
48+
49+
a_kernel<class kernel_A>(
50+
[=]() {
51+
float local = obj.FldArr[2];
52+
});
53+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: %clang_cc1 -fsycl -fsycl-is-device -I %S/Inputs -triple spir64-unknown-unknown-sycldevice -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s
2+
3+
// This test checks a kernel argument that is union with both array and non-array fields.
4+
5+
#include "sycl.hpp"
6+
7+
using namespace cl::sycl;
8+
9+
union MyUnion {
10+
int FldInt;
11+
char FldChar;
12+
float FldArr[3];
13+
};
14+
15+
template <typename name, typename Func>
16+
__attribute__((sycl_kernel)) void a_kernel(Func kernelFunc) {
17+
kernelFunc();
18+
}
19+
20+
int main() {
21+
22+
MyUnion obj;
23+
24+
a_kernel<class kernel_A>(
25+
[=]() {
26+
float local = obj.FldArr[2];
27+
});
28+
}
29+
30+
// CHECK kernel_A parameters
31+
// CHECK: define spir_kernel void @{{.*}}kernel_A(%union.{{.*}}.MyUnion* byval(%union.{{.*}}.MyUnion) align 4 [[MEM_ARG:%[a-zA-Z0-9_]+]])
32+
33+
// Check lambda object alloca
34+
// CHECK: [[LOCAL_OBJECT:%0]] = alloca %"class.{{.*}}.anon", align 4
35+
36+
// CHECK: [[L_STRUCT_ADDR:%[a-zA-Z0-9_]+]] = getelementptr inbounds %"class.{{.*}}.anon", %"class.{{.*}}.anon"* [[LOCAL_OBJECT]], i32 0, i32 0
37+
// CHECK: [[MEMCPY_DST:%[0-9a-zA-Z_]+]] = bitcast %union.{{.*}}MyUnion* [[L_STRUCT_ADDR]] to i8*
38+
// CHECK: [[MEMCPY_SRC:%[0-9a-zA-Z_]+]] = bitcast %union.{{.*}}MyUnion* [[MEM_ARG]] to i8*
39+
// CHECK: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 [[MEMCPY_DST]], i8* align 4 [[MEMCPY_SRC]], i64 12, i1 false)
40+
// CHECK: [[ACC_CAST1:%[0-9]+]] = addrspacecast %"class.{{.*}}.anon"* [[LOCAL_OBJECT]] to %"class.{{.*}}.anon" addrspace(4)*
41+
// CHECK: call spir_func void @{{.*}}(%"class.{{.*}}.anon" addrspace(4)* [[ACC_CAST1]])
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -verify -fsyntax-only %s
2+
3+
// This test checks if compiler reports compilation error on an attempt to pass
4+
// accessor/sampler as SYCL kernel parameter inside union.
5+
6+
#include "sycl.hpp"
7+
using namespace cl::sycl;
8+
9+
union union_with_sampler {
10+
cl::sycl::sampler smpl;
11+
// expected-error@-1 {{'cl::sycl::sampler' cannot be used inside a union kernel parameter}}
12+
};
13+
14+
template <typename name, typename Func>
15+
__attribute__((sycl_kernel)) void a_kernel(Func kernelFunc) {
16+
kernelFunc();
17+
}
18+
19+
int main() {
20+
21+
using Accessor =
22+
accessor<int, 1, access::mode::read_write, access::target::global_buffer>;
23+
24+
union union_with_accessor {
25+
Accessor member_acc[1];
26+
// expected-error@-1 {{'Accessor' (aka 'accessor<int, 1, access::mode::read_write, access::target::global_buffer>') cannot be used inside a union kernel parameter}}
27+
} union_acc;
28+
29+
union_with_sampler Sampler;
30+
31+
a_kernel<class kernel_A>(
32+
[=]() {
33+
Sampler.smpl.use();
34+
});
35+
36+
a_kernel<class kernel_B>(
37+
[=]() {
38+
union_acc.member_acc[1].use();
39+
});
40+
41+
return 0;
42+
}

0 commit comments

Comments
 (0)