@@ -229,6 +229,62 @@ struct ScatterKernel : KernelBase<T, VL, STRIDE> {
229
229
}
230
230
};
231
231
232
+ // Partial specialization of the gather kernel to test vector length = 1.
233
+ template <class T , unsigned STRIDE>
234
+ struct GatherKernel <T, 1 , STRIDE, TEST_VECTOR_NO_MASK>
235
+ : KernelBase<T, 1 , STRIDE> {
236
+ using B = KernelBase<T, 1 , STRIDE>;
237
+ using B::B;
238
+
239
+ static const char *get_name () { return " slm_gather_vl1" ; }
240
+
241
+ void operator ()(nd_item<1 > i) const SYCL_ESIMD_KERNEL {
242
+ slm_init (B::SLM_CHUNK_SIZE);
243
+
244
+ // first, read data into SLM
245
+ T val = scalar_load<T>(B::acc_in, B::get_wi_offset (i));
246
+ slm_scalar_store ((unsigned )(B::get_wi_local_offset (i) * sizeof (T)), val);
247
+
248
+ // wait for peers
249
+ esimd_barrier ();
250
+
251
+ // now load from SLM and write back to memory
252
+ unsigned wi_local_id = static_cast <unsigned >(i.get_local_id (0 ));
253
+ simd<uint32_t , 1 > offsets (wi_local_id * sizeof (T));
254
+ simd<T, 1 > vec1 = slm_gather<T, 1 >(offsets); /* ** THE TESTED API ***/
255
+ scalar_store (B::acc_out, B::get_wi_offset (i), (T)vec1[0 ]);
256
+ }
257
+ };
258
+
259
+ // Partial specialization of the scatter kernel to test vector length = 1.
260
+ template <class T , unsigned STRIDE>
261
+ struct ScatterKernel <T, 1 , STRIDE, TEST_VECTOR_NO_MASK>
262
+ : KernelBase<T, 1 , STRIDE> {
263
+ using B = KernelBase<T, 1 , STRIDE>;
264
+ using B::B;
265
+ static const char *get_name () { return " slm_scatter_vl1" ; }
266
+
267
+ ESIMD_INLINE void operator ()(nd_item<1 > i) const SYCL_ESIMD_KERNEL {
268
+ slm_init (B::SLM_CHUNK_SIZE);
269
+
270
+ // first, read data from memory into registers
271
+ simd<T, 1 > val;
272
+ val[0 ] = scalar_load<T>(B::acc_in, B::get_wi_offset (i));
273
+
274
+ // now write to SLM
275
+ unsigned wi_local_id = static_cast <unsigned >(i.get_local_id (0 ));
276
+ simd<uint32_t , 1 > offsets (wi_local_id * sizeof (T));
277
+ slm_scatter (val, offsets); /* ** THE TESTED API ***/
278
+
279
+ // wait for peers
280
+ esimd_barrier ();
281
+
282
+ // now copy data from SLM back to memory
283
+ T v = slm_scalar_load<T>(B::get_wi_local_offset (i) * sizeof (T));
284
+ scalar_store (B::acc_out, B::get_wi_offset (i), v);
285
+ }
286
+ };
287
+
232
288
// Direction of the SLM memory operation under test. test_impl compares its
// Dir template argument against MEM_GATHER to choose between GatherKernel and
// ScatterKernel.
enum MemIODir { MEM_SCATTER, MEM_GATHER };
233
289
234
290
// Verification algorithm depends on whether gather or scatter result is tested.
@@ -287,7 +343,8 @@ static bool verify(T *A, size_t size) {
287
343
288
344
template <class T , unsigned VL, unsigned STRIDE, MemIODir Dir, TestCase TC>
289
345
bool test_impl (queue q) {
290
- size_t size = VL * STRIDE;
346
+ size_t size = VL == 1 ? 8 * STRIDE : VL * STRIDE;
347
+
291
348
using KernelType =
292
349
std::conditional_t <Dir == MEM_GATHER, GatherKernel<T, VL, STRIDE, TC>,
293
350
ScatterKernel<T, VL, STRIDE, TC>>;
@@ -367,13 +424,22 @@ template <class T, unsigned VL, unsigned STRIDE> bool test(queue q) {
367
424
return passed;
368
425
}
369
426
427
+ template <class T , unsigned STRIDE> bool test_vl1 (queue q) {
428
+ bool passed = true ;
429
+ std::cout << " \n " ;
430
+ passed &= test_impl<T, 1 , STRIDE, MEM_GATHER, TEST_VECTOR_NO_MASK>(q);
431
+ passed &= test_impl<T, 1 , STRIDE, MEM_SCATTER, TEST_VECTOR_NO_MASK>(q);
432
+ return passed;
433
+ }
434
+
370
435
int main (int argc, char **argv) {
371
436
queue q (esimd_test::ESIMDSelector{}, esimd_test::createExceptionHandler ());
372
437
373
438
auto dev = q.get_device ();
374
439
std::cout << " Running on " << dev.get_info <info::device::name>() << " \n " ;
375
440
376
441
bool passed = true ;
442
+ passed &= test_vl1<char , 3 >(q);
377
443
passed &= test<char , 16 , 3 >(q);
378
444
passed &= test<char , 32 , 3 >(q);
379
445
passed &= test<short , 8 , 8 >(q);
@@ -387,6 +453,7 @@ int main(int argc, char **argv) {
387
453
passed &= test<float , 8 , 2 >(q);
388
454
passed &= test<float , 16 , 5 >(q);
389
455
passed &= test<float , 32 , 3 >(q);
456
+ passed &= test_vl1<float , 7 >(q);
390
457
391
458
std::cout << (!passed ? " TEST FAILED\n " : " TEST Passed\n " );
392
459
return passed ? 0 : 1 ;
0 commit comments