Skip to content

Commit 9b9639a

Browse files
authored
[SYCL] Correction to recognition of constant-size array. (#2082)
Signed-off-by: rdeodhar <[email protected]>
1 parent e5341eb commit 9b9639a

File tree

3 files changed: +71 -6 lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -802,7 +802,9 @@ class KernelObjVisitor {
802802
template <typename... Handlers>
803803
void VisitArrayElements(FieldDecl *FD, QualType FieldTy,
804804
Handlers &... handlers) {
805-
const ConstantArrayType *CAT = cast<ConstantArrayType>(FieldTy);
805+
const ConstantArrayType *CAT =
806+
SemaRef.getASTContext().getAsConstantArrayType(FieldTy);
807+
assert(CAT && "Should only be called on constant-size array.");
806808
QualType ET = CAT->getElementType();
807809
int64_t ElemCount = CAT->getSize().getSExtValue();
808810
std::initializer_list<int>{(handlers.enterArray(), 0)...};
@@ -1014,7 +1016,8 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
10141016
// Return true if not copyable, false if copyable.
10151017
bool checkNotCopyableToKernel(const FieldDecl *FD, const QualType &FieldTy) {
10161018
if (FieldTy->isArrayType()) {
1017-
if (const auto *CAT = dyn_cast<ConstantArrayType>(FieldTy)) {
1019+
if (const auto *CAT =
1020+
SemaRef.getASTContext().getAsConstantArrayType(FieldTy)) {
10181021
QualType ET = CAT->getElementType();
10191022
return checkNotCopyableToKernel(FD, ET);
10201023
}

clang/test/SemaSYCL/array-kernel-param.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ __attribute__((sycl_kernel)) void a_kernel(Func kernelFunc) {
1212
kernelFunc();
1313
}
1414

15+
template <typename T>
16+
struct S {
17+
T a[3];
18+
};
19+
1520
int main() {
1621

1722
using Accessor =
@@ -22,6 +27,7 @@ int main() {
2227
struct struct_acc_t {
2328
Accessor member_acc[2];
2429
} struct_acc;
30+
S<int> s;
2531

2632
struct foo_inner {
2733
int foo_inner_x;
@@ -56,6 +62,11 @@ int main() {
5662
[=]() {
5763
foo local = struct_array[1];
5864
});
65+
66+
a_kernel<class kernel_E>(
67+
[=]() {
68+
int local = s.a[2];
69+
});
5970
}
6071

6172
// Check kernel_A parameters
@@ -100,8 +111,8 @@ int main() {
100111
// CHECK-NEXT: ParmVarDecl {{.*}} used _arg_member_acc 'cl::sycl::id<1>'
101112
// CHECK-NEXT: CompoundStmt
102113
// CHECK-NEXT: DeclStmt
103-
// CHECK-NEXT: VarDecl {{.*}} used '(lambda at {{.*}}array-kernel-param.cpp{{.*}})' cinit
104-
// CHECK-NEXT: InitListExpr {{.*}} '(lambda at {{.*}}array-kernel-param.cpp{{.*}})'
114+
// CHECK-NEXT: VarDecl {{.*}} used '(lambda at {{.*}}array-kernel-param.cpp:57:7)' cinit
115+
// CHECK-NEXT: InitListExpr {{.*}} '(lambda at {{.*}}array-kernel-param.cpp:57:7)'
105116
// CHECK-NEXT: InitListExpr {{.*}} 'struct_acc_t'
106117
// CHECK-NEXT: InitListExpr {{.*}} 'Accessor [2]'
107118
// CHECK-NEXT: CXXConstructExpr {{.*}} 'Accessor [2]'
@@ -201,3 +212,21 @@ int main() {
201212
// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} '_arg_foo_inner_z' 'int'
202213
// CHECK-NEXT: ImplicitCastExpr
203214
// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} '_arg_foo_c' 'int'
215+
216+
// Check kernel_E parameters
217+
// CHECK: FunctionDecl {{.*}}kernel_E{{.*}} 'void (int, int, int)'
218+
// CHECK-NEXT: ParmVarDecl {{.*}} used _arg_a 'int':'int'
219+
// CHECK-NEXT: ParmVarDecl {{.*}} used _arg_a 'int':'int'
220+
// CHECK-NEXT: ParmVarDecl {{.*}} used _arg_a 'int':'int'
221+
// CHECK-NEXT: CompoundStmt
222+
// CHECK-NEXT: DeclStmt
223+
// CHECK-NEXT: VarDecl {{.*}} used '(lambda at {{.*}}array-kernel-param.cpp:67:7)' cinit
224+
// CHECK-NEXT: InitListExpr {{.*}} '(lambda at {{.*}}array-kernel-param.cpp:67:7)'
225+
// CHECK-NEXT: InitListExpr {{.*}} 'S<int>'
226+
// CHECK-NEXT: InitListExpr {{.*}} 'int [3]'
227+
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int':'int'
228+
// CHECK-NEXT: DeclRefExpr {{.*}} 'int':'int'
229+
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int':'int'
230+
// CHECK-NEXT: DeclRefExpr {{.*}} 'int':'int'
231+
// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int':'int'
232+
// CHECK-NEXT: DeclRefExpr {{.*}} 'int':'int'

sycl/test/array_param/array-kernel-param-nested-run.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
using namespace cl::sycl;
1313

14-
constexpr size_t c_num_items = 100;
14+
constexpr size_t c_num_items = 10;
1515
range<1> num_items{c_num_items}; // range<1>(num_items)
1616

1717
// Change if tests are added/removed
18-
static int testCount = 1;
18+
static int testCount = 2;
1919
static int passCount;
2020

2121
template <typename T>
@@ -84,6 +84,36 @@ bool test_accessor_array_in_struct(queue &myQueue) {
8484
return verify_1D("Accessor array in struct", c_num_items, output, ref);
8585
}
8686

87+
template <typename T>
88+
struct S {
89+
T a[c_num_items];
90+
};
91+
bool test_templated_array_in_struct(queue &myQueue) {
92+
std::array<int, c_num_items> output;
93+
std::array<int, c_num_items> ref;
94+
init(ref, 3, 3);
95+
96+
auto out_buffer = buffer<int, 1>(output.data(), num_items);
97+
98+
S<int> sint;
99+
S<long long> sll;
100+
init(sint.a, 1, 1);
101+
init(sll.a, 2, 2);
102+
103+
myQueue.submit([&](handler &cgh) {
104+
using Accessor =
105+
accessor<int, 1, access::mode::read_write, access::target::global_buffer>;
106+
auto output_accessor = out_buffer.get_access<access::mode::write>(cgh);
107+
108+
cgh.parallel_for<class templated_array_in_struct>(num_items, [=](cl::sycl::id<1> index) {
109+
output_accessor[index] = sint.a[index] + sll.a[index];
110+
});
111+
});
112+
const auto HostAccessor = out_buffer.get_access<cl::sycl::access::mode::read>();
113+
114+
return verify_1D("Templated array in struct", c_num_items, output, ref);
115+
}
116+
87117
bool run_tests() {
88118
queue Q([](exception_list L) {
89119
for (auto ep : L) {
@@ -103,6 +133,9 @@ bool run_tests() {
103133
if (test_accessor_array_in_struct(Q)) {
104134
++passCount;
105135
}
136+
if (test_templated_array_in_struct(Q)) {
137+
++passCount;
138+
}
106139

107140
auto D = Q.get_device();
108141
const char *devType = D.is_host() ? "Host" : D.is_cpu() ? "CPU" : "GPU";

Comments (0): 0 commit comments