@@ -86,8 +86,10 @@ template <typename Group> bool GroupAny(bool pred) {
86
86
}
87
87
88
88
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
// is_native_broadcast<T> is a bool_constant: the removed definition tested
// only detail::is_arithmetic<T>; the added definition additionally requires
// that T is not half.
89
+ // FIXME: Do not special-case for half once all backends support all data types.
89
90
template <typename T>
90
- using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value>;
91
+ using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
92
+ !std::is_same<T, half>::value>;
91
93
92
94
template <typename T, typename IdT = size_t >
93
95
using EnableIfNativeBroadcast = detail::enable_if_t <
@@ -121,6 +123,13 @@ template <typename T, typename IdT = size_t>
121
123
using EnableIfGenericBroadcast = detail::enable_if_t <
122
124
is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
123
125
126
+ // FIXME: Disable widening once all backends support all data types.
// Type mapping performed by the alias below:
//   cl_char,  cl_short  -> cl_int
//   cl_uchar, cl_ushort -> cl_uint
//   any other T         -> T (unchanged)
127
+ template <typename T>
128
+ using WidenOpenCLTypeTo32_t = conditional_t <
129
+ std::is_same<T, cl_char>() || std::is_same<T, cl_short>(), cl_int,
130
+ conditional_t <std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
131
+ cl_uint, T>>;
132
+
124
133
// Broadcast with scalar local index
125
134
// Work-group supports any integral type
126
135
// Sub-group currently supports only uint32_t
@@ -133,21 +142,17 @@ EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
133
142
using GroupIdT = typename GroupId<Group>::type;
134
143
GroupIdT GroupLocalId = static_cast <GroupIdT>(local_id);
135
144
using OCLT = detail::ConvertToOpenCLType_t<T>;
145
+ using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
136
146
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
137
- OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
147
+ WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
138
148
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
139
149
return __spirv_GroupBroadcast (group_scope<Group>::value, OCLX, OCLId);
140
150
}
141
151
// Bitcast broadcast with a scalar local index: x is bit_cast to BroadcastT,
// broadcast in that representation, and the result is bit_cast back to T.
template <typename Group, typename T, typename IdT>
142
152
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast (T x, IdT local_id) {
143
- using GroupIdT = typename GroupId<Group>::type;
144
- GroupIdT GroupLocalId = static_cast <GroupIdT>(local_id);
145
153
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
146
- using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
147
154
auto BroadcastX = bit_cast<BroadcastT>(x);
148
- OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
149
- BroadcastT Result =
150
- __spirv_GroupBroadcast (group_scope<Group>::value, BroadcastX, OCLId);
155
// The added line forwards BroadcastX and the unmodified local_id to
// GroupBroadcast<Group>, replacing the removed id conversion and the direct
// __spirv_GroupBroadcast call.
+ BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
151
156
return bit_cast<T>(Result);
152
157
}
153
158
template <typename Group, typename T, typename IdT>
@@ -173,31 +178,21 @@ EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
173
178
}
174
179
using IdT = vec<size_t , Dimensions>;
175
180
using OCLT = detail::ConvertToOpenCLType_t<T>;
181
+ using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
176
182
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
177
183
IdT VecId;
178
184
for (int i = 0 ; i < Dimensions; ++i) {
179
185
VecId[i] = local_id[Dimensions - i - 1 ];
180
186
}
181
- OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
187
+ WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
182
188
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
183
189
return __spirv_GroupBroadcast (group_scope<Group>::value, OCLX, OCLId);
184
190
}
185
191
// Bitcast broadcast with an id<Dimensions> local index: x is bit_cast to
// BroadcastT, broadcast in that representation, and the result is bit_cast
// back to T.
template <typename Group, typename T, int Dimensions>
186
192
EnableIfBitcastBroadcast<T> GroupBroadcast (T x, id<Dimensions> local_id) {
187
- if (Dimensions == 1 ) {
188
- return GroupBroadcast<Group>(x, local_id[0 ]);
189
- }
190
- using IdT = vec<size_t , Dimensions>;
191
193
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
192
- using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
193
- IdT VecId;
194
- for (int i = 0 ; i < Dimensions; ++i) {
195
- VecId[i] = local_id[Dimensions - i - 1 ];
196
- }
197
194
auto BroadcastX = bit_cast<BroadcastT>(x);
198
- OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
199
- BroadcastT Result =
200
- __spirv_GroupBroadcast (group_scope<Group>::value, BroadcastX, OCLId);
195
// The added line forwards BroadcastX and the unmodified id to
// GroupBroadcast<Group>; the removed lines (1-D shortcut, reversed-index
// VecId construction, direct __spirv_GroupBroadcast call) are no longer
// performed in this overload.
+ BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
201
196
return bit_cast<T>(Result);
202
197
}
203
198
template <typename Group, typename T, int Dimensions>
0 commit comments