Skip to content

Commit 0c3c763

Browse files
authored
[SYCL] Infer IdT in GroupBroadcast from group type (#2115)
- Work-group => any integral type - Sub-group => unsigned 32-bit integer Signed-off-by: John Pennycook <[email protected]>
1 parent 80464c3 commit 0c3c763

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,29 @@ template <typename Group> bool GroupAny(bool pred) {
4242
}
4343

4444
// Broadcast with scalar local index
45+
// Work-group supports any integral type
46+
// Sub-group currently supports only uint32_t
4547
template <typename Group, typename T, typename IdT>
46-
detail::enable_if_t<std::is_integral<IdT>::value, T>
48+
detail::enable_if_t<is_group<Group>::value && std::is_integral<IdT>::value, T>
4749
GroupBroadcast(T x, IdT local_id) {
4850
using OCLT = detail::ConvertToOpenCLType_t<T>;
4951
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
5052
OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
5153
OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(local_id);
5254
return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
5355
}
56+
template <typename Group, typename T, typename IdT>
57+
detail::enable_if_t<is_sub_group<Group>::value && std::is_integral<IdT>::value,
58+
T>
59+
GroupBroadcast(T x, IdT local_id) {
60+
using SGIdT = uint32_t;
61+
SGIdT sg_local_id = static_cast<SGIdT>(local_id);
62+
using OCLT = detail::ConvertToOpenCLType_t<T>;
63+
using OCLIdT = detail::ConvertToOpenCLType_t<SGIdT>;
64+
OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
65+
OCLIdT ocl_id = detail::convertDataToType<SGIdT, OCLIdT>(sg_local_id);
66+
return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
67+
}
5468

5569
// Broadcast with vector local index
5670
template <typename Group, typename T, int Dimensions>

sycl/include/CL/sycl/detail/type_traits.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
__SYCL_INLINE_NAMESPACE(cl) {
1919
namespace sycl {
20+
template <int Dimensions> class group;
21+
namespace intel {
22+
struct sub_group;
23+
} // namespace intel
2024
namespace detail {
2125
namespace half_impl {
2226
class half;
@@ -302,6 +306,20 @@ template <access::address_space AS, class DataT>
302306
using const_if_const_AS = DataT;
303307
#endif
304308

309+
template <typename T> struct is_group : std::false_type {};
310+
311+
template <int Dimensions>
312+
struct is_group<group<Dimensions>> : std::true_type {};
313+
314+
template <typename T> struct is_sub_group : std::false_type {};
315+
316+
template <> struct is_sub_group<intel::sub_group> : std::true_type {};
317+
318+
template <typename T>
319+
struct is_generic_group
320+
: std::integral_constant<bool,
321+
is_group<T>::value || is_sub_group<T>::value> {};
322+
305323
} // namespace detail
306324
} // namespace sycl
307325
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/CL/sycl/intel/group_algorithm.hpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,6 @@ template <> inline id<3> linear_id_to_id(range<3> r, size_t linear_id) {
7777
return result;
7878
}
7979

80-
template <typename T> struct is_group : std::false_type {};
81-
82-
template <int Dimensions>
83-
struct is_group<group<Dimensions>> : std::true_type {};
84-
85-
template <typename T> struct is_sub_group : std::false_type {};
86-
87-
template <> struct is_sub_group<intel::sub_group> : std::true_type {};
88-
89-
template <typename T>
90-
struct is_generic_group
91-
: std::integral_constant<bool,
92-
is_group<T>::value || is_sub_group<T>::value> {};
93-
9480
template <typename T, class BinaryOperation> struct identity {};
9581

9682
template <typename T, typename V> struct identity<T, intel::plus<V>> {

0 commit comments

Comments
 (0)