@@ -47,6 +47,7 @@ int main(void) {
47
47
using MaskElT = typename simd_mask<1 >::element_type;
48
48
MaskElT *M = malloc_shared<MaskElT>(Size, q);
49
49
int *C = malloc_shared<int >(Size, q);
50
+ int *C1 = malloc_shared<int >(Size, q);
50
51
constexpr int VAL0 = 0 ;
51
52
constexpr int VAL1 = 1 ;
52
53
constexpr int VAL2 = 3 ;
@@ -58,6 +59,7 @@ int main(void) {
58
59
// bit representation in a mask element:
59
60
M[i] = (i % SUB_VL) >= (SUB_VL / 2 ) ? (i % SUB_VL - 1 ) : 0 ; // 00120012 ...
60
61
C[i] = VAL0;
62
+ C1[i] = VAL0;
61
63
}
62
64
63
65
try {
@@ -71,11 +73,20 @@ int main(void) {
71
73
// m: 0012001200120012
72
74
// va.sel.sel: 1111
73
75
// vb.sel.sel: 4444
74
- // vc: 1144
76
+ // vc: 4411
75
77
simd<int , SUB_VL> vc =
76
78
esimd::merge (va.select <SUB_VL * 2 , 2 >(0 ).select <SUB_VL, 1 >(1 ),
77
79
vb.select <SUB_VL * 1 , 2 >(1 ).select <SUB_VL, 1 >(0 ), m);
78
80
vc.copy_to (C + i * VL);
81
+
82
+ // also check that
83
+ // vc = esimd::merge(a, b, m)
84
+ // is equivalent to
85
+ // vc.merge(a, b, m)
86
+ simd<int , SUB_VL> vc1;
87
+ vc1.merge (va.select <SUB_VL * 2 , 2 >(0 ).select <SUB_VL, 1 >(1 ).read (),
88
+ vb.select <SUB_VL * 1 , 2 >(1 ).select <SUB_VL, 1 >(0 ).read (), m);
89
+ vc1.copy_to (C1 + i * VL);
79
90
});
80
91
});
81
92
e.wait ();
@@ -84,6 +95,7 @@ int main(void) {
84
95
sycl::free (A, q);
85
96
sycl::free (B, q);
86
97
sycl::free (C, q);
98
+ sycl::free (C1, q);
87
99
sycl::free (M, q);
88
100
return 1 ;
89
101
}
@@ -98,14 +110,20 @@ int main(void) {
98
110
for (int i = 0 ; i < Size; ++i) {
99
111
int j = i % VL;
100
112
int gold =
101
- j >= SUB_VL ? VAL0 : ((j % SUB_VL) >= (SUB_VL / 2 ) ? VAL2 + 1 : VAL1 );
113
+ j >= SUB_VL ? VAL0 : ((j % SUB_VL) >= (SUB_VL / 2 ) ? VAL1 : VAL2 + 1 );
102
114
103
115
if (C[i] != gold) {
104
116
if (++err_cnt < 10 ) {
105
- std::cout << " failed at index " << i << " , " << C[i] << " != " << gold
117
+ std::cout << " (esimd::merge) failed at index " << i << " , " << C[i] << " != " << gold
106
118
<< " (gold)\n " ;
107
119
}
108
120
}
121
+ if (C1[i] != gold) {
122
+ if (++err_cnt < 10 ) {
123
+ std::cout << " (simd::merge) failed at index " << i << " , " << C1[i] << " != " << gold
124
+ << " (gold)\n " ;
125
+ }
126
+ }
109
127
}
110
128
if (err_cnt > 0 ) {
111
129
std::cout << " pass rate: "
@@ -116,6 +134,7 @@ int main(void) {
116
134
sycl::free (A, q);
117
135
sycl::free (B, q);
118
136
sycl::free (C, q);
137
+ sycl::free (C1, q);
119
138
sycl::free (M, q);
120
139
std::cout << (err_cnt > 0 ? " FAILED\n " : " Passed\n " );
121
140
return err_cnt > 0 ? 1 : 0 ;
0 commit comments