Skip to content

Commit 2dff81d

Browse files
committed
Generalize test code a bit
1 parent d684185 commit 2dff81d

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

sycl/unittests/accessor/AccessorIterator.cpp

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,22 @@ class AccessorIteratorTest : public ::testing::Test {
3030
}
3131
}
3232

33+
template <int TotalDimensions, int CurrentDimension = 3, typename Container,
34+
typename... Indices>
35+
auto &&accessHelper(Container &&C, int Idx, Indices... Ids) {
36+
if constexpr (CurrentDimension > TotalDimensions) {
37+
(void)Idx;
38+
return accessHelper<TotalDimensions, CurrentDimension - 1>(C, Ids...);
39+
} else
40+
return accessHelper<TotalDimensions, CurrentDimension - 1>(C[Idx],
41+
Ids...);
42+
}
43+
44+
template <int TotalDimensions, int CurrentDimension = 3, typename Container>
45+
auto &&accessHelper(Container &&C, int Idx) {
46+
return C[Idx];
47+
}
48+
3349
template <int Dimensions, typename T = int>
3450
void checkPartialCopyThroughIteratorWithoutOffset(
3551
const sycl::range<Dimensions> &fullShape,
@@ -52,27 +68,16 @@ class AccessorIteratorTest : public ::testing::Test {
5268

5369
{
5470
auto fullAccessor = buffer.template get_access<sycl::access_mode::read>();
55-
56-
if constexpr (Dimensions == 1) {
57-
for (size_t x = 0; x < copyShape[0]; ++x) {
58-
ASSERT_EQ(copied[x], reference[x]);
59-
}
60-
} else if constexpr (Dimensions == 2) {
61-
size_t linear = 0;
62-
for (size_t y = 0; y < copyShape[0]; ++y) {
63-
for (size_t x = 0; x < copyShape[1]; ++x) {
64-
ASSERT_EQ(copied[linear], fullAccessor[y][x]);
65-
++linear;
66-
}
67-
}
68-
} else {
69-
size_t linear = 0;
70-
for (size_t z = 0; z < copyShape[0]; ++z) {
71-
for (size_t y = 0; y < copyShape[1]; ++y) {
72-
for (size_t x = 0; x < copyShape[2]; ++x) {
73-
ASSERT_EQ(copied[linear], fullAccessor[z][y][x]);
74-
++linear;
75-
}
71+
size_t linearId = 0;
72+
sycl::id<3> shapeToCheck(Dimensions > 2 ? copyShape[Dimensions - 3] : 1,
73+
Dimensions > 1 ? copyShape[Dimensions - 2] : 1,
74+
copyShape[Dimensions - 1]);
75+
for (size_t z = 0; z < shapeToCheck[0]; ++z) {
76+
for (size_t y = 0; y < shapeToCheck[1]; ++y) {
77+
for (size_t x = 0; x < shapeToCheck[2]; ++x) {
78+
auto value = accessHelper<Dimensions>(fullAccessor, z, y, x);
79+
ASSERT_EQ(copied[linearId], value);
80+
++linearId;
7681
}
7782
}
7883
}

0 commit comments

Comments
 (0)