Skip to content

[SYCL] Add support for union types as kernel parameter #2285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Aug 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7be969a
[SYCL] Add support for union
smanna12 Aug 7, 2020
6de6dff
add tests
smanna12 Aug 11, 2020
60f02ac
Fix Clang-format issue
smanna12 Aug 11, 2020
fc3999b
Fix Clang-format issue
smanna12 Aug 11, 2020
e71672c
Fix Clang-format issue
smanna12 Aug 11, 2020
c36ab3f
Merge remote-tracking branch 'intel_llvm/sycl' into UnionKernelArgument
smanna12 Aug 12, 2020
99e8b2a
update tests and code changes
smanna12 Aug 13, 2020
2e7b74b
Fix Clang format issue
smanna12 Aug 13, 2020
3e4d4fc
Fix Clang format issue
smanna12 Aug 13, 2020
d3a5172
Fix test
smanna12 Aug 13, 2020
7802eda
Update tests and patch based on review comments
smanna12 Aug 13, 2020
2ed64f3
Fix Clang-format issue
smanna12 Aug 13, 2020
e9c65c0
Fix runtime test failure and add new integration header test
smanna12 Aug 13, 2020
1376cad
Add diagnostic tests
smanna12 Aug 13, 2020
bc09151
Fix clang format issue
smanna12 Aug 13, 2020
5931e4a
Address review commensts
smanna12 Aug 14, 2020
fabd978
Address review comment and fix clang-format issues
smanna12 Aug 14, 2020
0e15676
Fix runtime test
smanna12 Aug 14, 2020
ede7a0b
Fix typo on runtime test
smanna12 Aug 14, 2020
7338c0d
Fix runtime test
smanna12 Aug 14, 2020
6a80b99
Fix runtime test
smanna12 Aug 14, 2020
fdfbe19
Add empty base case for windows failure
smanna12 Aug 14, 2020
03460a1
Fix sema codes
smanna12 Aug 14, 2020
f4450a6
Fix clang-format issue
smanna12 Aug 14, 2020
f772189
Fix clang-format issue and update source codes
smanna12 Aug 14, 2020
db9f49d
Fix clang-format issues
smanna12 Aug 14, 2020
347b21c
update test based on reiew
smanna12 Aug 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -9736,6 +9736,8 @@ def warn_opencl_generic_address_space_arg : Warning<
"passing non-generic address space pointer to %0"
" may cause dynamic conversion affecting performance">,
InGroup<Conversion>, DefaultIgnore;
def err_bad_union_kernel_param_members : Error<
"%0 cannot be used inside a union kernel parameter">;

// OpenCL v2.0 s6.13.6 -- Builtin Pipe Functions
def err_opencl_builtin_pipe_first_arg : Error<
Expand Down
153 changes: 147 additions & 6 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,9 @@ class KernelObjVisitor {
else if (ElementTy->isStructureOrClassType())
VisitRecord(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
handlers...);
else if (ElementTy->isUnionType())
VisitUnion(Owner, ArrayField, ElementTy->getAsCXXRecordDecl(),
handlers...);
else if (ElementTy->isArrayType())
VisitArrayElements(ArrayField, ElementTy, handlers...);
else if (ElementTy->isScalarType())
Expand Down Expand Up @@ -858,6 +861,65 @@ class KernelObjVisitor {
void VisitRecord(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper, Handlers &... handlers);

// Base case, only calls these when filtered.
template <typename... FilteredHandlers, typename ParentTy>
std::enable_if_t<(sizeof...(FilteredHandlers) > 0)>
VisitUnion(FilteredHandlers &... handlers, const CXXRecordDecl *Owner,
ParentTy &Parent, const CXXRecordDecl *Wrapper) {
(void)std::initializer_list<int>{
(handlers.enterUnion(Owner, Parent), 0)...};
VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...);
(void)std::initializer_list<int>{
(handlers.leaveUnion(Owner, Parent), 0)...};
}

// Handle empty base case.
template <typename ParentTy>
void VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper) {}

template <typename... FilteredHandlers, typename ParentTy,
typename CurHandler, typename... Handlers>
std::enable_if_t<!CurHandler::VisitUnionBody &&
(sizeof...(FilteredHandlers) > 0)>
VisitUnion(FilteredHandlers &... filtered_handlers,
const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
Handlers &... handlers) {
VisitUnion<FilteredHandlers...>(filtered_handlers..., Owner, Parent,
Wrapper, handlers...);
}

template <typename... FilteredHandlers, typename ParentTy,
typename CurHandler, typename... Handlers>
std::enable_if_t<CurHandler::VisitUnionBody &&
(sizeof...(FilteredHandlers) > 0)>
VisitUnion(FilteredHandlers &... filtered_handlers,
const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
Handlers &... handlers) {
VisitUnion<FilteredHandlers..., CurHandler>(
filtered_handlers..., cur_handler, Owner, Parent, Wrapper, handlers...);
}

// Add overloads without having filtered-handlers
// to handle leading-empty argument packs.
template <typename ParentTy, typename CurHandler, typename... Handlers>
std::enable_if_t<!CurHandler::VisitUnionBody>
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
Handlers &... handlers) {
VisitUnion(Owner, Parent, Wrapper, handlers...);
}

template <typename ParentTy, typename CurHandler, typename... Handlers>
std::enable_if_t<CurHandler::VisitUnionBody>
VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent,
const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
Handlers &... handlers) {
VisitUnion<CurHandler>(cur_handler, Owner, Parent, Wrapper, handlers...);
}

template <typename... Handlers>
void VisitRecordHelper(const CXXRecordDecl *Owner,
clang::CXXRecordDecl::base_class_const_range Range,
Expand Down Expand Up @@ -943,6 +1005,11 @@ class KernelObjVisitor {
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
VisitRecord(Owner, Field, RD, handlers...);
}
} else if (FieldTy->isUnionType()) {
if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl();
VisitUnion(Owner, Field, RD, handlers...);
}
} else if (FieldTy->isReferenceType())
KF_FOR_EACH(handleReferenceType, Field, FieldTy);
else if (FieldTy->isPointerType())
Expand Down Expand Up @@ -982,6 +1049,7 @@ class SyclKernelFieldHandler {
SyclKernelFieldHandler(Sema &S) : SemaRef(S) {}

public:
static constexpr const bool VisitUnionBody = false;
// Mark these virtual so that we can use override in the implementer classes,
// despite virtual dispatch never being used.

Expand All @@ -1006,6 +1074,7 @@ class SyclKernelFieldHandler {
}
virtual bool handleSyclHalfType(FieldDecl *, QualType) { return true; }
virtual bool handleStructType(FieldDecl *, QualType) { return true; }
virtual bool handleUnionType(FieldDecl *, QualType) { return true; }
virtual bool handleReferenceType(FieldDecl *, QualType) { return true; }
virtual bool handlePointerType(FieldDecl *, QualType) { return true; }
virtual bool handleArrayType(FieldDecl *, QualType) { return true; }
Expand All @@ -1025,6 +1094,8 @@ class SyclKernelFieldHandler {
virtual bool leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) {
return true;
}
virtual bool enterUnion(const CXXRecordDecl *, FieldDecl *) { return true; }
virtual bool leaveUnion(const CXXRecordDecl *, FieldDecl *) { return true; }

// The following are used for stepping through array elements.

Expand All @@ -1047,7 +1118,6 @@ class SyclKernelFieldHandler {
class SyclKernelFieldChecker : public SyclKernelFieldHandler {
bool IsInvalid = false;
DiagnosticsEngine &Diag;

// Check whether the object should be disallowed from being copied to kernel.
// Return true if not copyable, false if copyable.
bool checkNotCopyableToKernel(const FieldDecl *FD, const QualType &FieldTy) {
Expand Down Expand Up @@ -1202,6 +1272,65 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
}
};

// A type to check the validity of accessing accessor/sampler/stream
// types as kernel parameters inside union.
class SyclKernelUnionChecker : public SyclKernelFieldHandler {
int UnionCount = 0;
bool IsInvalid = false;
DiagnosticsEngine &Diag;

public:
SyclKernelUnionChecker(Sema &S)
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
bool isValid() { return !IsInvalid; }
static constexpr const bool VisitUnionBody = true;

bool checkType(SourceLocation Loc, QualType Ty) {
if (UnionCount) {
IsInvalid = true;
Diag.Report(Loc, diag::err_bad_union_kernel_param_members) << Ty;
}
return isValid();
}

bool enterUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
++UnionCount;
return true;
}

bool leaveUnion(const CXXRecordDecl *RD, FieldDecl *FD) {
--UnionCount;
return true;
}

bool handleSyclAccessorType(FieldDecl *FD, QualType FieldTy) final {
return checkType(FD->getLocation(), FieldTy);
}

bool handleSyclAccessorType(const CXXBaseSpecifier &BS,
QualType FieldTy) final {
return checkType(BS.getBeginLoc(), FieldTy);
}

bool handleSyclSamplerType(FieldDecl *FD, QualType FieldTy) final {
return checkType(FD->getLocation(), FieldTy);
}

bool handleSyclSamplerType(const CXXBaseSpecifier &BS,
QualType FieldTy) final {
return checkType(BS.getBeginLoc(), FieldTy);
}

bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final {
return checkType(FD->getLocation(), FieldTy);
}

bool handleSyclStreamType(const CXXBaseSpecifier &BS,
QualType FieldTy) final {
return checkType(BS.getBeginLoc(), FieldTy);
}
};

// A type to Create and own the FunctionDecl for the kernel.
class SyclKernelDeclCreator : public SyclKernelFieldHandler {
FunctionDecl *KernelDecl;
Expand Down Expand Up @@ -1414,6 +1543,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
return true;
}

bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
return handleScalarType(FD, FieldTy);
}

bool handleSyclHalfType(FieldDecl *FD, QualType FieldTy) final {
addParam(FD, FieldTy);
return true;
Expand Down Expand Up @@ -1749,6 +1882,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
return true;
}

bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
return handleScalarType(FD, FieldTy);
}

bool enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
CXXCastPath BasePath;
QualType DerivedTy(RD->getTypeForDecl(), 0);
Expand Down Expand Up @@ -1953,6 +2090,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
return true;
}

bool handleUnionType(FieldDecl *FD, QualType FieldTy) final {
return handleScalarType(FD, FieldTy);
}

bool handleSyclStreamType(FieldDecl *FD, QualType FieldTy) final {
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
return true;
Expand Down Expand Up @@ -2034,14 +2175,14 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
}
}

SyclKernelFieldChecker Checker(*this);

SyclKernelFieldChecker FieldChecker(*this);
SyclKernelUnionChecker UnionChecker(*this);
KernelObjVisitor Visitor{*this};
DiagnosingSYCLKernel = true;
Visitor.VisitRecordBases(KernelObj, Checker);
Visitor.VisitRecordFields(KernelObj, Checker);
Visitor.VisitRecordBases(KernelObj, FieldChecker, UnionChecker);
Visitor.VisitRecordFields(KernelObj, FieldChecker, UnionChecker);
DiagnosingSYCLKernel = false;
if (!Checker.isValid())
if (!FieldChecker.isValid() || !UnionChecker.isValid())
KernelFunc->setInvalidDecl();
}

Expand Down
53 changes: 53 additions & 0 deletions clang/test/CodeGenSYCL/union-kernel-param-ih.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -fsycl-int-header=%t.h %s -o %t.out
// RUN: FileCheck -input-file=%t.h %s

// This test checks the integration header generated when
// the kernel argument is union.

// CHECK: #include <CL/sycl/detail/kernel_desc.hpp>

// CHECK: class kernel_A;

// CHECK: __SYCL_INLINE_NAMESPACE(cl) {
// CHECK-NEXT: namespace sycl {
// CHECK-NEXT: namespace detail {

// CHECK: static constexpr
// CHECK-NEXT: const char* const kernel_names[] = {
// CHECK-NEXT: "_ZTSZ4mainE8kernel_A"
// CHECK-NEXT: };

// CHECK: static constexpr
// CHECK-NEXT: const kernel_param_desc_t kernel_signatures[] = {
// CHECK-NEXT: //--- _ZTSZ4mainE8kernel_A
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 12, 0 },
// CHECK-EMPTY:
// CHECK-NEXT:};

// CHECK: static constexpr
// CHECK-NEXT: const unsigned kernel_signature_start[] = {
// CHECK-NEXT: 0 // _ZTSZ4mainE8kernel_A
// CHECK-NEXT: };

// CHECK: template <> struct KernelInfo<class kernel_A> {

union MyUnion {
int FldInt;
char FldChar;
float FldArr[3];
};

template <typename name, typename Func>
__attribute__((sycl_kernel)) void a_kernel(Func kernelFunc) {
kernelFunc();
}

int main() {

MyUnion obj;

a_kernel<class kernel_A>(
[=]() {
float local = obj.FldArr[2];
});
}
41 changes: 41 additions & 0 deletions clang/test/CodeGenSYCL/union-kernel-param.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: %clang_cc1 -fsycl -fsycl-is-device -I %S/Inputs -triple spir64-unknown-unknown-sycldevice -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s

// This test checks a kernel argument that is union with both array and non-array fields.

#include "sycl.hpp"

using namespace cl::sycl;

union MyUnion {
int FldInt;
char FldChar;
float FldArr[3];
};

template <typename name, typename Func>
__attribute__((sycl_kernel)) void a_kernel(Func kernelFunc) {
kernelFunc();
}

int main() {

MyUnion obj;

a_kernel<class kernel_A>(
[=]() {
float local = obj.FldArr[2];
});
}

// CHECK kernel_A parameters
// CHECK: define spir_kernel void @{{.*}}kernel_A(%union.{{.*}}.MyUnion* byval(%union.{{.*}}.MyUnion) align 4 [[MEM_ARG:%[a-zA-Z0-9_]+]])

// Check lambda object alloca
// CHECK: [[LOCAL_OBJECT:%0]] = alloca %"class.{{.*}}.anon", align 4

// CHECK: [[L_STRUCT_ADDR:%[a-zA-Z0-9_]+]] = getelementptr inbounds %"class.{{.*}}.anon", %"class.{{.*}}.anon"* [[LOCAL_OBJECT]], i32 0, i32 0
// CHECK: [[MEMCPY_DST:%[0-9a-zA-Z_]+]] = bitcast %union.{{.*}}MyUnion* [[L_STRUCT_ADDR]] to i8*
// CHECK: [[MEMCPY_SRC:%[0-9a-zA-Z_]+]] = bitcast %union.{{.*}}MyUnion* [[MEM_ARG]] to i8*
// CHECK: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 4 [[MEMCPY_DST]], i8* align 4 [[MEMCPY_SRC]], i64 12, i1 false)
// CHECK: [[ACC_CAST1:%[0-9]+]] = addrspacecast %"class.{{.*}}.anon"* [[LOCAL_OBJECT]] to %"class.{{.*}}.anon" addrspace(4)*
// CHECK: call spir_func void @{{.*}}(%"class.{{.*}}.anon" addrspace(4)* [[ACC_CAST1]])
42 changes: 42 additions & 0 deletions clang/test/SemaSYCL/union-kernel-param-neg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -verify -fsyntax-only %s

// This test checks if compiler reports compilation error on an attempt to pass
// accessor/sampler as SYCL kernel parameter inside union.

#include "sycl.hpp"
using namespace cl::sycl;

union union_with_sampler {
cl::sycl::sampler smpl;
// expected-error@-1 {{'cl::sycl::sampler' cannot be used inside a union kernel parameter}}
};

template <typename name, typename Func>
__attribute__((sycl_kernel)) void a_kernel(Func kernelFunc) {
kernelFunc();
}

int main() {

using Accessor =
accessor<int, 1, access::mode::read_write, access::target::global_buffer>;

union union_with_accessor {
Accessor member_acc[1];
// expected-error@-1 {{'Accessor' (aka 'accessor<int, 1, access::mode::read_write, access::target::global_buffer>') cannot be used inside a union kernel parameter}}
} union_acc;

union_with_sampler Sampler;

a_kernel<class kernel_A>(
[=]() {
Sampler.smpl.use();
});

a_kernel<class kernel_B>(
[=]() {
union_acc.member_acc[1].use();
});

return 0;
}
Loading