@@ -295,29 +295,71 @@ struct sub_group {
295
295
PI_INVALID_DEVICE);
296
296
#endif
297
297
}
298
-
298
+ #ifdef __SYCL_DEVICE_ONLY__
299
+ #ifdef __NVPTX__
299
300
template <int N, typename T, access::address_space Space>
300
301
sycl::detail::enable_if_t <
301
- sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
302
- N != 1 ,
302
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
303
303
vec<T, N>>
304
304
load (const multi_ptr<T, Space> src) const {
305
- #ifdef __SYCL_DEVICE_ONLY__
306
- #ifdef __NVPTX__
307
305
vec<T, N> res;
308
306
for (int i = 0 ; i < N; ++i) {
309
307
res[i] = *(src.get () + i * get_max_local_range ()[0 ] + get_local_id ()[0 ]);
310
308
}
311
309
return res;
312
- #else
310
+ }
311
+ #else // __NVPTX__
312
+ template <int N, typename T, access::address_space Space>
313
+ sycl::detail::enable_if_t <
314
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
315
+ N != 1 && N != 3 && N != 16 ,
316
+ vec<T, N>>
317
+ load (const multi_ptr<T, Space> src) const {
313
318
return sycl::detail::sub_group::load<N, T>(src);
314
- #endif // __NVPTX__
315
- #else
319
+ }
320
+
321
+ template <int N, typename T, access::address_space Space>
322
+ sycl::detail::enable_if_t <
323
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
324
+ N == 16 ,
325
+ vec<T, 16 >>
326
+ load (const multi_ptr<T, Space> src) const {
327
+ return {sycl::detail::sub_group::load<8 , T>(src),
328
+ sycl::detail::sub_group::load<8 , T>(src +
329
+ 8 * get_max_local_range ()[0 ])};
330
+ }
331
+
332
+ template <int N, typename T, access::address_space Space>
333
+ sycl::detail::enable_if_t <
334
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
335
+ N == 3 ,
336
+ vec<T, 3 >>
337
+ load (const multi_ptr<T, Space> src) const {
338
+ return {
339
+ sycl::detail::sub_group::load<1 , T>(src),
340
+ sycl::detail::sub_group::load<2 , T>(src + get_max_local_range ()[0 ])};
341
+ }
342
+
343
+ template <int N, typename T, access::address_space Space>
344
+ sycl::detail::enable_if_t <
345
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
346
+ N == 1 ,
347
+ vec<T, 1 >>
348
+ load (const multi_ptr<T, Space> src) const {
349
+ return sycl::detail::sub_group::load (src);
350
+ }
351
+ #endif // ___NVPTX___
352
+ #else // __SYCL_DEVICE_ONLY__
353
+ template <int N, typename T, access::address_space Space>
354
+ sycl::detail::enable_if_t <
355
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
356
+ vec<T, N>>
357
+ load (const multi_ptr<T, Space> src) const {
316
358
(void )src;
317
359
throw runtime_error (" Sub-groups are not supported on host device." ,
318
360
PI_INVALID_DEVICE);
319
- #endif
320
361
}
362
+ #endif // __SYCL_DEVICE_ONLY__
321
363
322
364
template <int N, typename T, access::address_space Space>
323
365
sycl::detail::enable_if_t <
@@ -337,25 +379,6 @@ struct sub_group {
337
379
#endif
338
380
}
339
381
340
- template <int N, typename T, access::address_space Space>
341
- sycl::detail::enable_if_t <
342
- sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
343
- N == 1 ,
344
- vec<T, 1 >>
345
- load (const multi_ptr<T, Space> src) const {
346
- #ifdef __SYCL_DEVICE_ONLY__
347
- #ifdef __NVPTX__
348
- return src.get ()[get_local_id ()[0 ]];
349
- #else
350
- return sycl::detail::sub_group::load (src);
351
- #endif // __NVPTX__
352
- #else
353
- (void )src;
354
- throw runtime_error (" Sub-groups are not supported on host device." ,
355
- PI_INVALID_DEVICE);
356
- #endif
357
- }
358
-
359
382
#ifdef __SYCL_DEVICE_ONLY__
360
383
// Method for decorated pointer
361
384
template <typename T>
@@ -437,45 +460,63 @@ struct sub_group {
437
460
#endif
438
461
}
439
462
463
+ #ifdef __SYCL_DEVICE_ONLY__
464
+ #ifdef __NVPTX__
465
+ template <int N, typename T, access::address_space Space>
466
+ sycl::detail::enable_if_t <
467
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value>
468
+ store (multi_ptr<T, Space> dst, const vec<T, N> &x) const {
469
+ for (int i = 0 ; i < N; ++i) {
470
+ *(dst.get () + i * get_max_local_range ()[0 ] + get_local_id ()[0 ]) = x[i];
471
+ }
472
+ }
473
+ #else // __NVPTX__
474
+ template <int N, typename T, access::address_space Space>
475
+ sycl::detail::enable_if_t <
476
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
477
+ N != 1 && N != 3 && N != 16 >
478
+ store (multi_ptr<T, Space> dst, const vec<T, N> &x) const {
479
+ sycl::detail::sub_group::store (dst, x);
480
+ }
481
+
440
482
template <int N, typename T, access::address_space Space>
441
483
sycl::detail::enable_if_t <
442
484
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
443
485
N == 1 >
444
486
store (multi_ptr<T, Space> dst, const vec<T, 1 > &x) const {
445
- #ifdef __SYCL_DEVICE_ONLY__
446
- #ifdef __NVPTX__
447
- dst.get ()[get_local_id ()[0 ]] = x[0 ];
448
- #else
449
- store<T, Space>(dst, x);
450
- #endif // __NVPTX__
451
- #else
452
- (void )dst;
453
- (void )x;
454
- throw runtime_error (" Sub-groups are not supported on host device." ,
455
- PI_INVALID_DEVICE);
456
- #endif
487
+ sycl::detail::sub_group::store (dst, x);
457
488
}
458
489
459
490
template <int N, typename T, access::address_space Space>
460
491
sycl::detail::enable_if_t <
461
492
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
462
- N != 1 >
463
- store (multi_ptr<T, Space> dst, const vec<T, N> &x) const {
464
- #ifdef __SYCL_DEVICE_ONLY__
465
- #ifdef __NVPTX__
466
- for (int i = 0 ; i < N; ++i) {
467
- *(dst.get () + i * get_max_local_range ()[0 ] + get_local_id ()[0 ]) = x[i];
468
- }
469
- #else
470
- sycl::detail::sub_group::store (dst, x);
493
+ N == 3 >
494
+ store (multi_ptr<T, Space> dst, const vec<T, 3 > &x) const {
495
+ store<1 , T, Space>(dst, x.s0 ());
496
+ store<2 , T, Space>(dst + get_max_local_range ()[0 ], {x.s1 (), x.s2 ()});
497
+ }
498
+
499
+ template <int N, typename T, access::address_space Space>
500
+ sycl::detail::enable_if_t <
501
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
502
+ N == 16 >
503
+ store (multi_ptr<T, Space> dst, const vec<T, 16 > &x) const {
504
+ store<8 , T, Space>(dst, x.lo ());
505
+ store<8 , T, Space>(dst + 8 * get_max_local_range ()[0 ], x.hi ());
506
+ }
507
+
471
508
#endif // __NVPTX__
472
- #else
509
+ #else // __SYCL_DEVICE_ONLY__
510
+ template <int N, typename T, access::address_space Space>
511
+ sycl::detail::enable_if_t <
512
+ sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value>
513
+ store (multi_ptr<T, Space> dst, const vec<T, N> &x) const {
473
514
(void )dst;
474
515
(void )x;
475
516
throw runtime_error (" Sub-groups are not supported on host device." ,
476
517
PI_INVALID_DEVICE);
477
- #endif
478
518
}
519
+ #endif // __SYCL_DEVICE_ONLY__
479
520
480
521
template <int N, typename T, access::address_space Space>
481
522
sycl::detail::enable_if_t <
0 commit comments