Skip to content

[SYCL] Generalize group_algorithm helpers #12726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions sycl/include/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,42 +59,25 @@ template <> inline id<3> linear_id_to_id(range<3> r, size_t linear_id) {
}

// ---- get_local_linear_range
template <typename Group> size_t get_local_linear_range(Group g);
template <> inline size_t get_local_linear_range<group<1>>(group<1> g) {
return g.get_local_range(0);
}
template <> inline size_t get_local_linear_range<group<2>>(group<2> g) {
return g.get_local_range(0) * g.get_local_range(1);
}
template <> inline size_t get_local_linear_range<group<3>>(group<3> g) {
return g.get_local_range(0) * g.get_local_range(1) * g.get_local_range(2);
}
template <>
inline size_t get_local_linear_range<sycl::sub_group>(sycl::sub_group g) {
return g.get_local_range()[0];
template <typename Group> inline auto get_local_linear_range(Group g) {
auto local_range = g.get_local_range();
auto result = local_range[0];
for (size_t i = 1; i < Group::dimensions; ++i)
result *= local_range[i];
return result;
}

// ---- get_local_linear_id
template <typename Group>
inline typename Group::linear_id_type get_local_linear_id(Group g);

template <typename Group> inline auto get_local_linear_id(Group g) {
#ifdef __SYCL_DEVICE_ONLY__
#define __SYCL_GROUP_GET_LOCAL_LINEAR_ID(D) \
template <> \
inline group<D>::linear_id_type get_local_linear_id<group<D>>(group<D>) { \
nd_item<D> it = sycl::detail::Builder::getNDItem<D>(); \
return it.get_local_linear_id(); \
if constexpr (std::is_same_v<Group, group<1>> ||
std::is_same_v<Group, group<2>> ||
std::is_same_v<Group, group<3>>) {
auto it = sycl::detail::Builder::getNDItem<Group::dimensions>();
return it.get_local_linear_id();
}
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(1);
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(2);
__SYCL_GROUP_GET_LOCAL_LINEAR_ID(3);
#undef __SYCL_GROUP_GET_LOCAL_LINEAR_ID
#endif // __SYCL_DEVICE_ONLY__

template <>
inline sycl::sub_group::linear_id_type
get_local_linear_id<sycl::sub_group>(sycl::sub_group g) {
return g.get_local_id()[0];
return g.get_local_linear_id();
}

// ---- is_native_op
Expand Down