Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 119307e

Browse files
authored
[SYCL] Add local_accessor iterator tests (#1225)
sycl patch: intel/llvm#6692
1 parent bc5e945 commit 119307e

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

SYCL/Basic/accessor/accessor.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,62 @@ template <typename Acc> struct Wrapper2 {
6060

6161
template <typename Acc> struct Wrapper3 { Wrapper2<Acc> w2; };
6262

63+
template <typename GlobAcc, typename LocAcc>
64+
void testLocalAccItersImpl(sycl::handler &cgh, GlobAcc &globAcc, LocAcc &locAcc,
65+
bool testConstIter) {
66+
if (testConstIter) {
67+
cgh.single_task([=]() {
68+
size_t Idx = 0;
69+
for (auto &It : locAcc) {
70+
It = globAcc[Idx++];
71+
}
72+
Idx = 0;
73+
for (auto It = locAcc.cbegin(); It != locAcc.cend(); It++)
74+
globAcc[Idx++] = *It * 2 + 1;
75+
Idx = locAcc.size() - 1;
76+
for (auto It = locAcc.crbegin(); It != locAcc.crend(); It++)
77+
globAcc[Idx--] += *It;
78+
});
79+
} else {
80+
cgh.single_task([=]() {
81+
size_t Idx = 0;
82+
for (auto It = locAcc.begin(); It != locAcc.end(); It++)
83+
*It = globAcc[Idx++] * 2;
84+
for (auto &It : locAcc) {
85+
It++;
86+
}
87+
for (auto It = locAcc.rbegin(); It != locAcc.rend(); It++) {
88+
*It *= 2;
89+
*It += 1;
90+
}
91+
Idx = 0;
92+
for (auto &It : locAcc) {
93+
globAcc[Idx++] = It;
94+
}
95+
});
96+
}
97+
}
98+
99+
void testLocalAccIters(std::vector<int> &vec, bool testConstIter = false,
100+
bool test2D = false) {
101+
try {
102+
sycl::queue queue;
103+
sycl::buffer<int, 1> buf(vec.data(), vec.size());
104+
queue.submit([&](sycl::handler &cgh) {
105+
auto globAcc = buf.get_access<sycl::access::mode::read_write>(cgh);
106+
if (test2D) {
107+
sycl::local_accessor<int, 2> locAcc(sycl::range<2>{2, 16}, cgh);
108+
testLocalAccItersImpl(cgh, globAcc, locAcc, testConstIter);
109+
} else {
110+
sycl::local_accessor<int, 1> locAcc(32, cgh);
111+
testLocalAccItersImpl(cgh, globAcc, locAcc, testConstIter);
112+
}
113+
});
114+
} catch (sycl::exception &e) {
115+
std::cout << e.what() << std::endl;
116+
}
117+
}
118+
63119
int main() {
64120
// Host accessor.
65121
{
@@ -771,5 +827,40 @@ int main() {
771827
}
772828
}
773829

830+
// Test iterator methods with 1D local_accessor
831+
{
832+
std::vector<int> v(32);
833+
for (int i = 0; i < v.size(); ++i) {
834+
v[i] = i;
835+
}
836+
testLocalAccIters(v);
837+
for (int i = 0; i < v.size(); ++i)
838+
assert(v[i] == ((i * 2 + 1) * 2 + 1));
839+
840+
for (int i = 0; i < v.size(); ++i) {
841+
v[i] = i;
842+
}
843+
testLocalAccIters(v, true);
844+
for (int i = 0; i < v.size(); ++i)
845+
assert(v[i] == ((i * 2 + 1) + i));
846+
}
847+
// Test iterator methods with 2D local_accessor
848+
{
849+
std::vector<int> v(32);
850+
for (int i = 0; i < v.size(); ++i) {
851+
v[i] = i;
852+
}
853+
testLocalAccIters(v, false, true);
854+
for (int i = 0; i < v.size(); ++i)
855+
assert(v[i] == ((i * 2 + 1) * 2 + 1));
856+
857+
for (int i = 0; i < v.size(); ++i) {
858+
v[i] = i;
859+
}
860+
testLocalAccIters(v, true, true);
861+
for (int i = 0; i < v.size(); ++i)
862+
assert(v[i] == ((i * 2 + 1) + i));
863+
}
864+
774865
std::cout << "Test passed" << std::endl;
775866
}

0 commit comments

Comments
 (0)