Skip to content

[SYCL] Widen (u)int8/16 to (u)int32 and half to float in group_broadcast #5110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions sycl/include/CL/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ template <typename Group> bool GroupAny(bool pred) {
}

// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
// FIXME: Do not special-case for half once all backends support all data types.
template <typename T>
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value>;
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
!std::is_same<T, half>::value>;

template <typename T, typename IdT = size_t>
using EnableIfNativeBroadcast = detail::enable_if_t<
Expand Down Expand Up @@ -121,6 +123,13 @@ template <typename T, typename IdT = size_t>
using EnableIfGenericBroadcast = detail::enable_if_t<
is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;

// FIXME: Disable widening once all backends support all data types.
template <typename T>
using WidenOpenCLTypeTo32_t = conditional_t<
std::is_same<T, cl_char>() || std::is_same<T, cl_short>(), cl_int,
conditional_t<std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
cl_uint, T>>;

// Broadcast with scalar local index
// Work-group supports any integral type
// Sub-group currently supports only uint32_t
Expand All @@ -133,21 +142,17 @@ EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
using GroupIdT = typename GroupId<Group>::type;
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
using OCLT = detail::ConvertToOpenCLType_t<T>;
using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
}
template <typename Group, typename T, typename IdT>
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
using GroupIdT = typename GroupId<Group>::type;
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
auto BroadcastX = bit_cast<BroadcastT>(x);
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
BroadcastT Result =
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
return bit_cast<T>(Result);
}
template <typename Group, typename T, typename IdT>
Expand All @@ -173,31 +178,21 @@ EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
}
using IdT = vec<size_t, Dimensions>;
using OCLT = detail::ConvertToOpenCLType_t<T>;
using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
IdT VecId;
for (int i = 0; i < Dimensions; ++i) {
VecId[i] = local_id[Dimensions - i - 1];
}
OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
}
template <typename Group, typename T, int Dimensions>
EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
if (Dimensions == 1) {
return GroupBroadcast<Group>(x, local_id[0]);
}
using IdT = vec<size_t, Dimensions>;
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
IdT VecId;
for (int i = 0; i < Dimensions; ++i) {
VecId[i] = local_id[Dimensions - i - 1];
}
auto BroadcastX = bit_cast<BroadcastT>(x);
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
BroadcastT Result =
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
return bit_cast<T>(Result);
}
template <typename Group, typename T, int Dimensions>
Expand Down
Loading