Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 0565e79

Browse files
authored
[ESIMD] Add VL=1 test case to slm_gather/scatter test. (#522)
Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent 94831aa commit 0565e79

File tree

1 file changed

+68
-1
lines changed

1 file changed

+68
-1
lines changed

SYCL/ESIMD/api/slm_gather_scatter_heavy.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,62 @@ struct ScatterKernel : KernelBase<T, VL, STRIDE> {
229229
}
230230
};
231231

232+
// Partial specialization of the gather kernel to test vector length = 1.
233+
template <class T, unsigned STRIDE>
234+
struct GatherKernel<T, 1, STRIDE, TEST_VECTOR_NO_MASK>
235+
: KernelBase<T, 1, STRIDE> {
236+
using B = KernelBase<T, 1, STRIDE>;
237+
using B::B;
238+
239+
static const char *get_name() { return "slm_gather_vl1"; }
240+
241+
void operator()(nd_item<1> i) const SYCL_ESIMD_KERNEL {
242+
slm_init(B::SLM_CHUNK_SIZE);
243+
244+
// first, read data into SLM
245+
T val = scalar_load<T>(B::acc_in, B::get_wi_offset(i));
246+
slm_scalar_store((unsigned)(B::get_wi_local_offset(i) * sizeof(T)), val);
247+
248+
// wait for peers
249+
esimd_barrier();
250+
251+
// now load from SLM and write back to memory
252+
unsigned wi_local_id = static_cast<unsigned>(i.get_local_id(0));
253+
simd<uint32_t, 1> offsets(wi_local_id * sizeof(T));
254+
simd<T, 1> vec1 = slm_gather<T, 1>(offsets); /*** THE TESTED API ***/
255+
scalar_store(B::acc_out, B::get_wi_offset(i), (T)vec1[0]);
256+
}
257+
};
258+
259+
// Partial specialization of the scatter kernel to test vector length = 1.
260+
template <class T, unsigned STRIDE>
261+
struct ScatterKernel<T, 1, STRIDE, TEST_VECTOR_NO_MASK>
262+
: KernelBase<T, 1, STRIDE> {
263+
using B = KernelBase<T, 1, STRIDE>;
264+
using B::B;
265+
static const char *get_name() { return "slm_scatter_vl1"; }
266+
267+
ESIMD_INLINE void operator()(nd_item<1> i) const SYCL_ESIMD_KERNEL {
268+
slm_init(B::SLM_CHUNK_SIZE);
269+
270+
// first, read data from memory into registers
271+
simd<T, 1> val;
272+
val[0] = scalar_load<T>(B::acc_in, B::get_wi_offset(i));
273+
274+
// now write to SLM
275+
unsigned wi_local_id = static_cast<unsigned>(i.get_local_id(0));
276+
simd<uint32_t, 1> offsets(wi_local_id * sizeof(T));
277+
slm_scatter(val, offsets); /*** THE TESTED API ***/
278+
279+
// wait for peers
280+
esimd_barrier();
281+
282+
// now copy data from SLM back to memory
283+
T v = slm_scalar_load<T>(B::get_wi_local_offset(i) * sizeof(T));
284+
scalar_store(B::acc_out, B::get_wi_offset(i), v);
285+
}
286+
};
287+
232288
enum MemIODir { MEM_SCATTER, MEM_GATHER };
233289

234290
// Verification algorithm depends on whether gather or scatter result is tested.
@@ -287,7 +343,8 @@ static bool verify(T *A, size_t size) {
287343

288344
template <class T, unsigned VL, unsigned STRIDE, MemIODir Dir, TestCase TC>
289345
bool test_impl(queue q) {
290-
size_t size = VL * STRIDE;
346+
size_t size = VL == 1 ? 8 * STRIDE : VL * STRIDE;
347+
291348
using KernelType =
292349
std::conditional_t<Dir == MEM_GATHER, GatherKernel<T, VL, STRIDE, TC>,
293350
ScatterKernel<T, VL, STRIDE, TC>>;
@@ -367,13 +424,22 @@ template <class T, unsigned VL, unsigned STRIDE> bool test(queue q) {
367424
return passed;
368425
}
369426

427+
template <class T, unsigned STRIDE> bool test_vl1(queue q) {
428+
bool passed = true;
429+
std::cout << "\n";
430+
passed &= test_impl<T, 1, STRIDE, MEM_GATHER, TEST_VECTOR_NO_MASK>(q);
431+
passed &= test_impl<T, 1, STRIDE, MEM_SCATTER, TEST_VECTOR_NO_MASK>(q);
432+
return passed;
433+
}
434+
370435
int main(int argc, char **argv) {
371436
queue q(esimd_test::ESIMDSelector{}, esimd_test::createExceptionHandler());
372437

373438
auto dev = q.get_device();
374439
std::cout << "Running on " << dev.get_info<info::device::name>() << "\n";
375440

376441
bool passed = true;
442+
passed &= test_vl1<char, 3>(q);
377443
passed &= test<char, 16, 3>(q);
378444
passed &= test<char, 32, 3>(q);
379445
passed &= test<short, 8, 8>(q);
@@ -387,6 +453,7 @@ int main(int argc, char **argv) {
387453
passed &= test<float, 8, 2>(q);
388454
passed &= test<float, 16, 5>(q);
389455
passed &= test<float, 32, 3>(q);
456+
passed &= test_vl1<float, 7>(q);
390457

391458
std::cout << (!passed ? "TEST FAILED\n" : "TEST Passed\n");
392459
return passed ? 0 : 1;

0 commit comments

Comments
 (0)