Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 059256f

Browse files
committed
Fix test according to API semantics change.
Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent 6b803e7 commit 059256f

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

SYCL/ESIMD/api/esimd_merge.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ int main(void) {
4747
using MaskElT = typename simd_mask<1>::element_type;
4848
MaskElT *M = malloc_shared<MaskElT>(Size, q);
4949
int *C = malloc_shared<int>(Size, q);
50+
int *C1 = malloc_shared<int>(Size, q);
5051
constexpr int VAL0 = 0;
5152
constexpr int VAL1 = 1;
5253
constexpr int VAL2 = 3;
@@ -58,6 +59,7 @@ int main(void) {
5859
// bit representation in a mask element:
5960
M[i] = (i % SUB_VL) >= (SUB_VL / 2) ? (i % SUB_VL - 1) : 0; // 00120012 ...
6061
C[i] = VAL0;
62+
C1[i] = VAL0;
6163
}
6264

6365
try {
@@ -71,11 +73,20 @@ int main(void) {
7173
// m: 0012001200120012
7274
// va.sel.sel: 1111
7375
// vb.sel.sel: 4444
74-
// vc: 1144
76+
// vc: 4411
7577
simd<int, SUB_VL> vc =
7678
esimd::merge(va.select<SUB_VL * 2, 2>(0).select<SUB_VL, 1>(1),
7779
vb.select<SUB_VL * 1, 2>(1).select<SUB_VL, 1>(0), m);
7880
vc.copy_to(C + i * VL);
81+
82+
// also check that
83+
// vc = esimd::merge(a, b, m)
84+
// is equivalent to
85+
// vc.merge(a, b, m)
86+
simd<int, SUB_VL> vc1;
87+
vc1.merge(va.select<SUB_VL * 2, 2>(0).select<SUB_VL, 1>(1).read(),
88+
vb.select<SUB_VL * 1, 2>(1).select<SUB_VL, 1>(0).read(), m);
89+
vc1.copy_to(C1 + i * VL);
7990
});
8091
});
8192
e.wait();
@@ -84,6 +95,7 @@ int main(void) {
8495
sycl::free(A, q);
8596
sycl::free(B, q);
8697
sycl::free(C, q);
98+
sycl::free(C1, q);
8799
sycl::free(M, q);
88100
return 1;
89101
}
@@ -98,14 +110,20 @@ int main(void) {
98110
for (int i = 0; i < Size; ++i) {
99111
int j = i % VL;
100112
int gold =
101-
j >= SUB_VL ? VAL0 : ((j % SUB_VL) >= (SUB_VL / 2) ? VAL2 + 1 : VAL1);
113+
j >= SUB_VL ? VAL0 : ((j % SUB_VL) >= (SUB_VL / 2) ? VAL1 : VAL2 + 1);
102114

103115
if (C[i] != gold) {
104116
if (++err_cnt < 10) {
105-
std::cout << "failed at index " << i << ", " << C[i] << " != " << gold
117+
std::cout << "(esimd::merge) failed at index " << i << ", " << C[i] << " != " << gold
106118
<< " (gold)\n";
107119
}
108120
}
121+
if (C1[i] != gold) {
122+
if (++err_cnt < 10) {
123+
std::cout << "(simd::merge) failed at index " << i << ", " << C1[i] << " != " << gold
124+
<< " (gold)\n";
125+
}
126+
}
109127
}
110128
if (err_cnt > 0) {
111129
std::cout << " pass rate: "
@@ -116,6 +134,7 @@ int main(void) {
116134
sycl::free(A, q);
117135
sycl::free(B, q);
118136
sycl::free(C, q);
137+
sycl::free(C1, q);
119138
sycl::free(M, q);
120139
std::cout << (err_cnt > 0 ? "FAILED\n" : "Passed\n");
121140
return err_cnt > 0 ? 1 : 0;

0 commit comments

Comments
 (0)