Skip to content

Commit 017af4e

Browse files
authored
[SYCL] Enable algorithm support for sub_group (#1392)
- Adds static members to sub_group class. - sub_group member functions marked deprecated, to be removed later. - SPIR-V helpers expanded to convert SYCL group to SPIR-V scope. - Add workaround for half types Signed-off-by: John Pennycook <[email protected]>
1 parent c98559b commit 017af4e

File tree

10 files changed

+313
-149
lines changed

10 files changed

+313
-149
lines changed

sycl/include/CL/sycl/detail/defines.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,9 @@
3333
#endif
3434

3535
#if __cplusplus >= 201402
36-
#define __SYCL_DEPRECATED__ \
37-
[[deprecated("Replaced by in_order queue property")]]
36+
#define __SYCL_DEPRECATED__(message) [[deprecated(message)]]
3837
#elif !defined _MSC_VER
39-
#define __SYCL_DEPRECATED__ \
40-
__attribute__((deprecated("Replaced by in_order queue property")))
38+
#define __SYCL_DEPRECATED__(message) __attribute__((deprecated(message)))
4139
#else
42-
#define __SYCL_DEPRECATED__
40+
#define __SYCL_DEPRECATED__(message)
4341
#endif

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,46 @@
1616
#ifdef __SYCL_DEVICE_ONLY__
1717
__SYCL_INLINE_NAMESPACE(cl) {
1818
namespace sycl {
19+
namespace intel {
20+
struct sub_group;
21+
} // namespace intel
1922
namespace detail {
2023
namespace spirv {
2124

25+
template <typename Group> struct group_scope {};
26+
27+
template <int Dimensions> struct group_scope<group<Dimensions>> {
28+
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Workgroup;
29+
};
30+
31+
template <> struct group_scope<intel::sub_group> {
32+
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
33+
};
34+
35+
template <typename Group> bool GroupAll(bool pred) {
36+
return __spirv_GroupAll(group_scope<Group>::value, pred);
37+
}
38+
39+
template <typename Group> bool GroupAny(bool pred) {
40+
return __spirv_GroupAny(group_scope<Group>::value, pred);
41+
}
42+
2243
// Broadcast with scalar local index
23-
template <__spv::Scope::Flag S, typename T, typename IdT>
44+
template <typename Group, typename T, typename IdT>
2445
detail::enable_if_t<std::is_integral<IdT>::value, T>
2546
GroupBroadcast(T x, IdT local_id) {
2647
using OCLT = detail::ConvertToOpenCLType_t<T>;
2748
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
2849
OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
2950
OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(local_id);
30-
return __spirv_GroupBroadcast(S, ocl_x, ocl_id);
51+
return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
3152
}
3253

3354
// Broadcast with vector local index
34-
template <__spv::Scope::Flag S, typename T, int Dimensions>
55+
template <typename Group, typename T, int Dimensions>
3556
T GroupBroadcast(T x, id<Dimensions> local_id) {
3657
if (Dimensions == 1) {
37-
return GroupBroadcast<S>(x, local_id[0]);
58+
return GroupBroadcast<Group>(x, local_id[0]);
3859
}
3960
using IdT = vec<size_t, Dimensions>;
4061
using OCLT = detail::ConvertToOpenCLType_t<T>;
@@ -45,7 +66,7 @@ T GroupBroadcast(T x, id<Dimensions> local_id) {
4566
}
4667
OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
4768
OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(vec_id);
48-
return __spirv_GroupBroadcast(S, ocl_x, ocl_id);
69+
return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
4970
}
5071

5172
} // namespace spirv

0 commit comments

Comments
 (0)