Skip to content

Commit 1f3f9b9

Browse files
authored
[SYCL] Widen (u)int8/16 to (u)int32 and half to float in group_broadcast (#5110)
CPU device does not yet support the (u)int8/16 and half versions.
- Add FIXMEs.
- Bitcast half to int16_t (and then widen to int32_t) to keep the precision.
- Refactor the widening code into a separate helper.
- Add tests for all 3 group_broadcast overloads.
1 parent e83cb19 commit 1f3f9b9

File tree

2 files changed

+569
-438
lines changed

2 files changed

+569
-438
lines changed

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 16 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -86,8 +86,10 @@ template <typename Group> bool GroupAny(bool pred) {
8686
}
8787

8888
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
89+
// FIXME: Do not special-case for half once all backends support all data types.
8990
template <typename T>
90-
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value>;
91+
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
92+
!std::is_same<T, half>::value>;
9193

9294
template <typename T, typename IdT = size_t>
9395
using EnableIfNativeBroadcast = detail::enable_if_t<
@@ -121,6 +123,13 @@ template <typename T, typename IdT = size_t>
121123
using EnableIfGenericBroadcast = detail::enable_if_t<
122124
is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
123125

126+
// FIXME: Disable widening once all backends support all data types.
127+
template <typename T>
128+
using WidenOpenCLTypeTo32_t = conditional_t<
129+
std::is_same<T, cl_char>() || std::is_same<T, cl_short>(), cl_int,
130+
conditional_t<std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
131+
cl_uint, T>>;
132+
124133
// Broadcast with scalar local index
125134
// Work-group supports any integral type
126135
// Sub-group currently supports only uint32_t
@@ -133,21 +142,17 @@ EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
133142
using GroupIdT = typename GroupId<Group>::type;
134143
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
135144
using OCLT = detail::ConvertToOpenCLType_t<T>;
145+
using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
136146
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
137-
OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
147+
WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
138148
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
139149
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
140150
}
141151
template <typename Group, typename T, typename IdT>
142152
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
143-
using GroupIdT = typename GroupId<Group>::type;
144-
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
145153
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
146-
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
147154
auto BroadcastX = bit_cast<BroadcastT>(x);
148-
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
149-
BroadcastT Result =
150-
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
155+
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
151156
return bit_cast<T>(Result);
152157
}
153158
template <typename Group, typename T, typename IdT>
@@ -173,31 +178,21 @@ EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
173178
}
174179
using IdT = vec<size_t, Dimensions>;
175180
using OCLT = detail::ConvertToOpenCLType_t<T>;
181+
using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
176182
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
177183
IdT VecId;
178184
for (int i = 0; i < Dimensions; ++i) {
179185
VecId[i] = local_id[Dimensions - i - 1];
180186
}
181-
OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
187+
WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
182188
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
183189
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
184190
}
185191
template <typename Group, typename T, int Dimensions>
186192
EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
187-
if (Dimensions == 1) {
188-
return GroupBroadcast<Group>(x, local_id[0]);
189-
}
190-
using IdT = vec<size_t, Dimensions>;
191193
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
192-
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
193-
IdT VecId;
194-
for (int i = 0; i < Dimensions; ++i) {
195-
VecId[i] = local_id[Dimensions - i - 1];
196-
}
197194
auto BroadcastX = bit_cast<BroadcastT>(x);
198-
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
199-
BroadcastT Result =
200-
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
195+
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
201196
return bit_cast<T>(Result);
202197
}
203198
template <typename Group, typename T, int Dimensions>

0 commit comments

Comments
 (0)