Skip to content

Commit 2ad1d20

Browse files
committed
[SYCL] Add check that accessor class declared in cl::sycl namespace.
Signed-off-by: Vladimir Lazarev <[email protected]>
1 parent 124a88a commit 2ad1d20

File tree

3 files changed

+87
-13
lines changed

3 files changed

+87
-13
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -240,33 +240,32 @@ void BuildArgTys(ASTContext &Context,
240240
for (auto V : ArgDecls) {
241241
QualType ArgTy = V->getType();
242242
QualType ActualArgType = ArgTy;
243-
StringRef Name = ArgTy.getBaseTypeIdentifier()->getName();
244-
// TODO: harden this check with additional validation that this class is
245-
// declared in cl::sycl namespace
246-
if (std::string(Name) == "accessor") {
243+
std::string Name = ArgTy.getCanonicalType().getAsString();
244+
if (Name.find("class cl::sycl::accessor") != std::string::npos) {
247245
if (const auto *RecordDecl = ArgTy->getAsCXXRecordDecl()) {
248246
const auto *TemplateDecl =
249247
dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl);
250248
if (TemplateDecl) {
251249
// First parameter - data type
252250
QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType();
253251
// Fourth parameter - access target
254-
auto AccessQualifier = TemplateDecl->getTemplateArgs()[3].getAsIntegral();
252+
auto AccessQualifier =
253+
TemplateDecl->getTemplateArgs()[3].getAsIntegral();
255254
int64_t AccessTarget = AccessQualifier.getExtValue();
256255
Qualifiers Quals = PointeeType.getQualifiers();
257256
// TODO: Support all access targets
258257
switch (AccessTarget) {
259-
case target::global_buffer:
258+
case target::global_buffer:
260259
Quals.setAddressSpace(LangAS::opencl_global);
261-
break;
262-
case target::constant_buffer:
260+
break;
261+
case target::constant_buffer:
263262
Quals.setAddressSpace(LangAS::opencl_constant);
264-
break;
265-
case target::local:
263+
break;
264+
case target::local:
266265
Quals.setAddressSpace(LangAS::opencl_local);
267-
break;
268-
default:
269-
llvm_unreachable("Unsupported access target");
266+
break;
267+
default:
268+
llvm_unreachable("Unsupported access target");
270269
}
271270
// TODO: get address space from accessor template parameter.
272271
PointeeType =
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %clang -S --sycl -Xclang -ast-dump %s | FileCheck %s
2+
// XFAIL: *
3+
#include <CL/sycl.hpp>
4+
5+
int main() {
6+
int data = 5;
7+
cl::sycl::queue deviceQueue;
8+
cl::sycl::buffer<int, 1> bufferA(&data, cl::sycl::range<1>(1));
9+
10+
deviceQueue.submit([&](cl::sycl::handler &cgh) {
11+
auto accessorA = bufferA.template get_access<cl::sycl::access::mode::read_write>(cgh);
12+
cgh.single_task<class kernel_function>(
13+
[=]() {
14+
accessorA[0] += data;
15+
});
16+
});
17+
return 0;
18+
}
19+
// CHECK: kernel_function 'void (__global int *__global, int)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: %clang -S --sycl -Xclang -ast-dump %s | FileCheck %s
2+
// XFAIL: *
3+
#include <CL/sycl.hpp>
4+
5+
namespace foo {
6+
namespace cl {
7+
namespace sycl {
8+
class accessor {
9+
public:
10+
int field;
11+
};
12+
} // namespace sycl
13+
} // namespace cl
14+
} // namespace foo
15+
16+
class accessor {
17+
public:
18+
int field;
19+
};
20+
21+
typedef cl::sycl::accessor<int, 1, cl::sycl::access::mode::read_write,
22+
cl::sycl::access::target::global_buffer>
23+
MyAccessorTD;
24+
25+
using MyAccessorA = cl::sycl::accessor<int, 1, cl::sycl::access::mode::read_write,
26+
cl::sycl::access::target::global_buffer>;
27+
28+
int main() {
29+
int data = 5;
30+
cl::sycl::queue deviceQueue;
31+
cl::sycl::buffer<int, 1> bufferA(&data, cl::sycl::range<1>(1));
32+
foo::cl::sycl::accessor acc = {1};
33+
accessor acc1 = {1};
34+
35+
deviceQueue.submit([&](cl::sycl::handler &cgh) {
36+
auto accessorA = bufferA.template get_access<cl::sycl::access::mode::read_write>(cgh);
37+
MyAccessorTD accessorB = bufferA.template get_access<cl::sycl::access::mode::read_write>(cgh);
38+
MyAccessorA accessorC = bufferA.template get_access<cl::sycl::access::mode::read_write>(cgh);
39+
cgh.single_task<class fake_accessors>(
40+
[=]() {
41+
accessorA[0] = acc.field + acc1.field;
42+
});
43+
cgh.single_task<class accessor_typedef>(
44+
[=]() {
45+
accessorB[0] = acc.field + acc1.field;
46+
});
47+
cgh.single_task<class accessor_alias>(
48+
[=]() {
49+
accessorC[0] = acc.field + acc1.field;
50+
});
51+
});
52+
return 0;
53+
}
54+
// CHECK: fake_accessors 'void (__global int *__global, foo::cl::sycl::accessor, accessor)
55+
// CHECK: accessor_typedef 'void (__global int *__global, foo::cl::sycl::accessor, accessor)
56+
// CHECK: accessor_alias 'void (__global int *__global, foo::cl::sycl::accessor, accessor)

0 commit comments

Comments
 (0)