Skip to content

Commit 4e6452d

Browse files
authored
[SYCL] Support 3-, 16-elements vectors in SG load/store (#3617)
- When a 3-element vector is passed, it is packed as a single element followed by a 2-element vector; - when a 16-element vector is passed, it is packed as two sequential 8-element vectors. The corresponding test was changed as part of intel/llvm-test-suite#253.
1 parent 88939b2 commit 4e6452d

File tree

1 file changed

+92
-51
lines changed

1 file changed

+92
-51
lines changed

sycl/include/CL/sycl/ONEAPI/sub_group.hpp

Lines changed: 92 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -295,29 +295,71 @@ struct sub_group {
295295
PI_INVALID_DEVICE);
296296
#endif
297297
}
298-
298+
#ifdef __SYCL_DEVICE_ONLY__
299+
#ifdef __NVPTX__
299300
template <int N, typename T, access::address_space Space>
300301
sycl::detail::enable_if_t<
301-
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
302-
N != 1,
302+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
303303
vec<T, N>>
304304
load(const multi_ptr<T, Space> src) const {
305-
#ifdef __SYCL_DEVICE_ONLY__
306-
#ifdef __NVPTX__
307305
vec<T, N> res;
308306
for (int i = 0; i < N; ++i) {
309307
res[i] = *(src.get() + i * get_max_local_range()[0] + get_local_id()[0]);
310308
}
311309
return res;
312-
#else
310+
}
311+
#else // __NVPTX__
312+
template <int N, typename T, access::address_space Space>
313+
sycl::detail::enable_if_t<
314+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
315+
N != 1 && N != 3 && N != 16,
316+
vec<T, N>>
317+
load(const multi_ptr<T, Space> src) const {
313318
return sycl::detail::sub_group::load<N, T>(src);
314-
#endif // __NVPTX__
315-
#else
319+
}
320+
321+
template <int N, typename T, access::address_space Space>
322+
sycl::detail::enable_if_t<
323+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
324+
N == 16,
325+
vec<T, 16>>
326+
load(const multi_ptr<T, Space> src) const {
327+
return {sycl::detail::sub_group::load<8, T>(src),
328+
sycl::detail::sub_group::load<8, T>(src +
329+
8 * get_max_local_range()[0])};
330+
}
331+
332+
template <int N, typename T, access::address_space Space>
333+
sycl::detail::enable_if_t<
334+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
335+
N == 3,
336+
vec<T, 3>>
337+
load(const multi_ptr<T, Space> src) const {
338+
return {
339+
sycl::detail::sub_group::load<1, T>(src),
340+
sycl::detail::sub_group::load<2, T>(src + get_max_local_range()[0])};
341+
}
342+
343+
template <int N, typename T, access::address_space Space>
344+
sycl::detail::enable_if_t<
345+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
346+
N == 1,
347+
vec<T, 1>>
348+
load(const multi_ptr<T, Space> src) const {
349+
return sycl::detail::sub_group::load(src);
350+
}
351+
#endif // __NVPTX__
352+
#else // __SYCL_DEVICE_ONLY__
353+
template <int N, typename T, access::address_space Space>
354+
sycl::detail::enable_if_t<
355+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
356+
vec<T, N>>
357+
load(const multi_ptr<T, Space> src) const {
316358
(void)src;
317359
throw runtime_error("Sub-groups are not supported on host device.",
318360
PI_INVALID_DEVICE);
319-
#endif
320361
}
362+
#endif // __SYCL_DEVICE_ONLY__
321363

322364
template <int N, typename T, access::address_space Space>
323365
sycl::detail::enable_if_t<
@@ -337,25 +379,6 @@ struct sub_group {
337379
#endif
338380
}
339381

340-
template <int N, typename T, access::address_space Space>
341-
sycl::detail::enable_if_t<
342-
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
343-
N == 1,
344-
vec<T, 1>>
345-
load(const multi_ptr<T, Space> src) const {
346-
#ifdef __SYCL_DEVICE_ONLY__
347-
#ifdef __NVPTX__
348-
return src.get()[get_local_id()[0]];
349-
#else
350-
return sycl::detail::sub_group::load(src);
351-
#endif // __NVPTX__
352-
#else
353-
(void)src;
354-
throw runtime_error("Sub-groups are not supported on host device.",
355-
PI_INVALID_DEVICE);
356-
#endif
357-
}
358-
359382
#ifdef __SYCL_DEVICE_ONLY__
360383
// Method for decorated pointer
361384
template <typename T>
@@ -437,45 +460,63 @@ struct sub_group {
437460
#endif
438461
}
439462

463+
#ifdef __SYCL_DEVICE_ONLY__
464+
#ifdef __NVPTX__
465+
template <int N, typename T, access::address_space Space>
466+
sycl::detail::enable_if_t<
467+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value>
468+
store(multi_ptr<T, Space> dst, const vec<T, N> &x) const {
469+
for (int i = 0; i < N; ++i) {
470+
*(dst.get() + i * get_max_local_range()[0] + get_local_id()[0]) = x[i];
471+
}
472+
}
473+
#else // __NVPTX__
474+
template <int N, typename T, access::address_space Space>
475+
sycl::detail::enable_if_t<
476+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
477+
N != 1 && N != 3 && N != 16>
478+
store(multi_ptr<T, Space> dst, const vec<T, N> &x) const {
479+
sycl::detail::sub_group::store(dst, x);
480+
}
481+
440482
template <int N, typename T, access::address_space Space>
441483
sycl::detail::enable_if_t<
442484
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
443485
N == 1>
444486
store(multi_ptr<T, Space> dst, const vec<T, 1> &x) const {
445-
#ifdef __SYCL_DEVICE_ONLY__
446-
#ifdef __NVPTX__
447-
dst.get()[get_local_id()[0]] = x[0];
448-
#else
449-
store<T, Space>(dst, x);
450-
#endif // __NVPTX__
451-
#else
452-
(void)dst;
453-
(void)x;
454-
throw runtime_error("Sub-groups are not supported on host device.",
455-
PI_INVALID_DEVICE);
456-
#endif
487+
sycl::detail::sub_group::store(dst, x);
457488
}
458489

459490
template <int N, typename T, access::address_space Space>
460491
sycl::detail::enable_if_t<
461492
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
462-
N != 1>
463-
store(multi_ptr<T, Space> dst, const vec<T, N> &x) const {
464-
#ifdef __SYCL_DEVICE_ONLY__
465-
#ifdef __NVPTX__
466-
for (int i = 0; i < N; ++i) {
467-
*(dst.get() + i * get_max_local_range()[0] + get_local_id()[0]) = x[i];
468-
}
469-
#else
470-
sycl::detail::sub_group::store(dst, x);
493+
N == 3>
494+
store(multi_ptr<T, Space> dst, const vec<T, 3> &x) const {
495+
store<1, T, Space>(dst, x.s0());
496+
store<2, T, Space>(dst + get_max_local_range()[0], {x.s1(), x.s2()});
497+
}
498+
499+
template <int N, typename T, access::address_space Space>
500+
sycl::detail::enable_if_t<
501+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
502+
N == 16>
503+
store(multi_ptr<T, Space> dst, const vec<T, 16> &x) const {
504+
store<8, T, Space>(dst, x.lo());
505+
store<8, T, Space>(dst + 8 * get_max_local_range()[0], x.hi());
506+
}
507+
471508
#endif // __NVPTX__
472-
#else
509+
#else // __SYCL_DEVICE_ONLY__
510+
template <int N, typename T, access::address_space Space>
511+
sycl::detail::enable_if_t<
512+
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value>
513+
store(multi_ptr<T, Space> dst, const vec<T, N> &x) const {
473514
(void)dst;
474515
(void)x;
475516
throw runtime_error("Sub-groups are not supported on host device.",
476517
PI_INVALID_DEVICE);
477-
#endif
478518
}
519+
#endif // __SYCL_DEVICE_ONLY__
479520

480521
template <int N, typename T, access::address_space Space>
481522
sycl::detail::enable_if_t<

0 commit comments

Comments
 (0)