Skip to content

Commit 88f0935

Browse files
committed
[SYCL][CUDA] Enable sub-group loads and stores
Rather than implementing the SubgroupBlockReadINTEL and SubgroupBlockWriteINTEL intrinsics in libspirv, this commit adds a fallback path directly to the sub-group header. There are several reasons for this: 1) There are currently no INTEL extensions implemented in libspirv. 2) The load/store functions are expected to be rewritten soon to expose additional functionality, which may map to different SPIR-V. Signed-off-by: John Pennycook <[email protected]>
1 parent 74a68b7 commit 88f0935

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

sycl/include/CL/sycl/ONEAPI/sub_group.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,11 @@ struct sub_group {
230230
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value, T>
231231
load(const multi_ptr<T, Space> src) const {
232232
#ifdef __SYCL_DEVICE_ONLY__
233+
#ifdef __NVPTX__
234+
return src.get()[get_local_id()[0]];
235+
#else
233236
return sycl::detail::sub_group::load(src);
237+
#endif // __NVPTX__
234238
#else
235239
(void)src;
236240
throw runtime_error("Sub-groups are not supported on host device.",
@@ -258,7 +262,15 @@ struct sub_group {
258262
vec<T, N>>
259263
load(const multi_ptr<T, Space> src) const {
260264
#ifdef __SYCL_DEVICE_ONLY__
265+
#ifdef __NVPTX__
266+
vec<T, N> res;
267+
for (int i = 0; i < N; ++i) {
268+
res[i] = *(src.get() + i * get_max_local_range()[0] + get_local_id()[0]);
269+
}
270+
return res;
271+
#else
261272
return sycl::detail::sub_group::load<N, T>(src);
273+
#endif // __NVPTX__
262274
#else
263275
(void)src;
264276
throw runtime_error("Sub-groups are not supported on host device.",
@@ -291,7 +303,11 @@ struct sub_group {
291303
vec<T, 1>>
292304
load(const multi_ptr<T, Space> src) const {
293305
#ifdef __SYCL_DEVICE_ONLY__
306+
#ifdef __NVPTX__
307+
return src.get()[get_local_id()[0]];
308+
#else
294309
return sycl::detail::sub_group::load(src);
310+
#endif // __NVPTX__
295311
#else
296312
(void)src;
297313
throw runtime_error("Sub-groups are not supported on host device.",
@@ -304,7 +320,11 @@ struct sub_group {
304320
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value>
305321
store(multi_ptr<T, Space> dst, const T &x) const {
306322
#ifdef __SYCL_DEVICE_ONLY__
323+
#ifdef __NVPTX__
324+
dst.get()[get_local_id()[0]] = x;
325+
#else
307326
sycl::detail::sub_group::store(dst, x);
327+
#endif // __NVPTX__
308328
#else
309329
(void)dst;
310330
(void)x;
@@ -333,7 +353,11 @@ struct sub_group {
333353
N == 1>
334354
store(multi_ptr<T, Space> dst, const vec<T, 1> &x) const {
335355
#ifdef __SYCL_DEVICE_ONLY__
356+
#ifdef __NVPTX__
357+
dst.get()[get_local_id()[0]] = x[0];
358+
#else
336359
store<T, Space>(dst, x);
360+
#endif // __NVPTX__
337361
#else
338362
(void)dst;
339363
(void)x;
@@ -348,7 +372,13 @@ struct sub_group {
348372
N != 1>
349373
store(multi_ptr<T, Space> dst, const vec<T, N> &x) const {
350374
#ifdef __SYCL_DEVICE_ONLY__
375+
#ifdef __NVPTX__
376+
for (int i = 0; i < N; ++i) {
377+
*(dst.get() + i * get_max_local_range()[0] + get_local_id()[0]) = x[i];
378+
}
379+
#else
351380
sycl::detail::sub_group::store(dst, x);
381+
#endif // __NVPTX__
352382
#else
353383
(void)dst;
354384
(void)x;

0 commit comments

Comments
 (0)