Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 9a5a35f

Browse files
committed
[SYCL] Add multi_ptr test
1 parent 6e8d25a commit 9a5a35f

File tree

1 file changed

+185
-0
lines changed

1 file changed

+185
-0
lines changed

SYCL/Basic/multi_ptr.cpp

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -fsycl-dead-args-optimization %s -o %t.out
2+
// RUN: %HOST_RUN_PLACEHOLDER %t.out
3+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
4+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
5+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
6+
// RUN: %clangxx -DRESTRICT_WRITE_ACCESS_TO_CONSTANT_PTR -fsycl -fsycl-targets=%sycl_triple -fsycl-dead-args-optimization %s -o %t1.out
7+
// RUN: %HOST_RUN_PLACEHOLDER %t1.out
8+
// RUN: %CPU_RUN_PLACEHOLDER %t1.out
9+
// RUN: %GPU_RUN_PLACEHOLDER %t1.out
10+
// RUN: %ACC_RUN_PLACEHOLDER %t1.out
11+
12+
//==--------------- multi_ptr.cpp - SYCL multi_ptr test --------------------==//
13+
//
14+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
15+
// See https://llvm.org/LICENSE.txt for license information.
16+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
17+
//
18+
//===----------------------------------------------------------------------===//
19+
20+
#include <CL/sycl.hpp>
21+
#include <cassert>
22+
#include <iostream>
23+
#include <type_traits>
24+
25+
using namespace cl::sycl;
26+
27+
/* This is the class used to name the kernel for the runtime.
28+
* This must be done when the kernel is expressed as a lambda. */
29+
template <typename T> class testMultPtrKernel;
30+
template <typename T> class testMultPtrArrowOperatorKernel;
31+
32+
template <typename T> struct point {
33+
point(const point &rhs) : x(rhs.x), y(rhs.y) {}
34+
point(T x, T y) : x(x), y(y) {}
35+
point(T v) : x(v), y(v) {}
36+
point() : x(0), y(0) {}
37+
bool operator==(const T &rhs) { return rhs == x && rhs == y; }
38+
bool operator==(const point<T> &rhs) { return rhs.x == x && rhs.y == y; }
39+
T x;
40+
T y;
41+
};
42+
43+
template <typename T>
44+
void innerFunc(id<1> wiID, global_ptr<T> ptr_1, global_ptr<T> ptr_2,
45+
local_ptr<T> local_ptr) {
46+
T t = ptr_1[wiID.get(0)];
47+
local_ptr[wiID.get(0)] = t;
48+
t = local_ptr[wiID.get(0)];
49+
ptr_2[wiID.get(0)] = t;
50+
}
51+
52+
template <typename T> void testMultPtr() {
53+
T data_1[10];
54+
for (size_t i = 0; i < 10; ++i) {
55+
data_1[i] = 1;
56+
}
57+
T data_2[10];
58+
for (size_t i = 0; i < 10; ++i) {
59+
data_2[i] = 2;
60+
}
61+
62+
{
63+
range<1> numOfItems{10};
64+
buffer<T, 1> bufferData_1(data_1, numOfItems);
65+
buffer<T, 1> bufferData_2(data_2, numOfItems);
66+
queue myQueue;
67+
myQueue.submit([&](handler &cgh) {
68+
accessor<T, 1, access::mode::read, access::target::global_buffer,
69+
access::placeholder::false_t>
70+
accessorData_1(bufferData_1, cgh);
71+
accessor<T, 1, access::mode::read_write, access::target::global_buffer,
72+
access::placeholder::false_t>
73+
accessorData_2(bufferData_2, cgh);
74+
accessor<T, 1, access::mode::read_write, access::target::local>
75+
localAccessor(numOfItems, cgh);
76+
77+
cgh.parallel_for<class testMultPtrKernel<T>>(range<1>{10}, [=](id<1> wiID) {
78+
auto ptr_1 = make_ptr<T, access::address_space::global_space>(
79+
accessorData_1.get_pointer());
80+
auto ptr_2 = make_ptr<T, access::address_space::global_space>(
81+
accessorData_2.get_pointer());
82+
auto local_ptr = make_ptr<T, access::address_space::local_space>(
83+
localAccessor.get_pointer());
84+
85+
// General conversions in multi_ptr class
86+
T *RawPtr = nullptr;
87+
global_ptr<T> ptr_4(RawPtr);
88+
ptr_4 = RawPtr;
89+
90+
global_ptr<T> ptr_5(accessorData_1);
91+
92+
global_ptr<void> ptr_6((void *)RawPtr);
93+
94+
ptr_6 = (void *)RawPtr;
95+
96+
// Explicit conversions for device_ptr/host_ptr to global_ptr
97+
device_ptr<void> ptr_7((void *)RawPtr);
98+
global_ptr<void> ptr_8 = global_ptr<void>(ptr_7);
99+
host_ptr<void> ptr_9((void *)RawPtr);
100+
global_ptr<void> ptr_10 = global_ptr<void>(ptr_9);
101+
// TODO: need propagation of a7b763b26 patch to acl tool before testing
102+
// these conversions - otherwise the test would fail on accelerator
103+
// device during reversed translation from SPIR-V to LLVM IR
104+
// device_ptr<T> ptr_11(accessorData_1);
105+
// global_ptr<T> ptr_12 = global_ptr<T>(ptr_11);
106+
107+
innerFunc<T>(wiID.get(0), ptr_1, ptr_2, local_ptr);
108+
});
109+
});
110+
}
111+
for (size_t i = 0; i < 10; ++i) {
112+
assert(data_1[i] == 1 && "Expected data_1[i] == 1");
113+
}
114+
for (size_t i = 0; i < 10; ++i) {
115+
assert(data_2[i] == 1 && "Expected data_2[i] == 1");
116+
}
117+
}
118+
119+
template <typename T> void testMultPtrArrowOperator() {
120+
point<T> data_1[1] = {1};
121+
point<T> data_2[1] = {2};
122+
point<T> data_3[1] = {3};
123+
point<T> data_4[1] = {4};
124+
125+
{
126+
range<1> numOfItems{1};
127+
buffer<point<T>, 1> bufferData_1(data_1, numOfItems);
128+
buffer<point<T>, 1> bufferData_2(data_2, numOfItems);
129+
buffer<point<T>, 1> bufferData_3(data_3, numOfItems);
130+
buffer<point<T>, 1> bufferData_4(data_4, numOfItems);
131+
queue myQueue;
132+
myQueue.submit([&](handler &cgh) {
133+
accessor<point<T>, 1, access::mode::read, access::target::global_buffer,
134+
access::placeholder::false_t>
135+
accessorData_1(bufferData_1, cgh);
136+
accessor<point<T>, 1, access::mode::read, access::target::constant_buffer,
137+
access::placeholder::false_t>
138+
accessorData_2(bufferData_2, cgh);
139+
accessor<point<T>, 1, access::mode::read_write, access::target::local,
140+
access::placeholder::false_t>
141+
accessorData_3(1, cgh);
142+
accessor<point<T>, 1, access::mode::read, access::target::global_buffer,
143+
access::placeholder::false_t>
144+
accessorData_4(bufferData_4, cgh);
145+
146+
cgh.single_task<class testMultPtrArrowOperatorKernel<T>>([=]() {
147+
auto ptr_1 = make_ptr<point<T>, access::address_space::global_space>(
148+
accessorData_1.get_pointer());
149+
auto ptr_2 = make_ptr<point<T>, access::address_space::constant_space>(
150+
accessorData_2.get_pointer());
151+
auto ptr_3 = make_ptr<point<T>, access::address_space::local_space>(
152+
accessorData_3.get_pointer());
153+
auto ptr_4 =
154+
make_ptr<point<T>, access::address_space::global_device_space>(
155+
accessorData_4.get_pointer());
156+
157+
auto x1 = ptr_1 -> x;
158+
auto x2 = ptr_2 -> x;
159+
auto x3 = ptr_3 -> x;
160+
auto x4 = ptr_4 -> x;
161+
162+
static_assert(std::is_same<decltype(x1), T>::value,
163+
"Expected decltype(ptr_1->x) == T");
164+
static_assert(std::is_same<decltype(x2), T>::value,
165+
"Expected decltype(ptr_2->x) == T");
166+
static_assert(std::is_same<decltype(x3), T>::value,
167+
"Expected decltype(ptr_3->x) == T");
168+
static_assert(std::is_same<decltype(x4), T>::value,
169+
"Expected decltype(ptr_4->x) == T");
170+
});
171+
});
172+
}
173+
}
174+
175+
int main() {
176+
testMultPtr<int>();
177+
testMultPtr<float>();
178+
testMultPtr<point<int>>();
179+
testMultPtr<point<float>>();
180+
181+
testMultPtrArrowOperator<int>();
182+
testMultPtrArrowOperator<float>();
183+
184+
return 0;
185+
}

0 commit comments

Comments
 (0)