Skip to content

[SYCL] Allow group algorithms to accept a function object with an explicit type #9784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 62 additions & 97 deletions sycl/include/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,19 @@ using is_plus_or_multiplies_if_complex = std::integral_constant<
is_multiplies<T, BinaryOperation>::value)
: std::true_type::value)>;

// used to transform a vector op to a scalar op;
// e.g. sycl::plus<std::vec<T, N>> to sycl::plus<T>
template <typename T> struct get_scalar_binary_op;

template <template <typename> typename F, typename T, int n>
struct get_scalar_binary_op<F<sycl::vec<T, n>>> {
using type = F<T>;
};

template <template <typename> typename F> struct get_scalar_binary_op<F<void>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this template specialization needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To handle the function types without an explicit argument, e.g. in a case like reduce_over_group(g, int2{}, sycl::plus<>()), we will be instantiating get_scalar_binary_op with sycl::plus<void> (= sycl::plus<>) , which does not match the other partial specialization, so we need this to compile correctly.

using type = F<void>;
};

// ---- identity_for_ga_op
// the group algorithms support std::complex, limited to sycl::plus operation
// get the correct identity for group algorithm operation.
Expand Down Expand Up @@ -200,11 +213,8 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
reduce_over_group(Group g, T x, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(
std::is_same_v<decltype(binary_op(x, x)), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(x, x)), float>),
std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match reduction accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
return sycl::detail::calc<__spv::GroupOperation::Reduce>(
Expand Down Expand Up @@ -239,24 +249,21 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) {
#endif
}

template <typename Group, typename T, int N, class BinaryOperation>
std::enable_if_t<
(is_group_v<std::decay_t<Group>> &&
detail::is_vector_arithmetic_or_complex<sycl::vec<T, N>>::value &&
detail::is_native_op<sycl::vec<T, N>, BinaryOperation>::value),
sycl::vec<T, N>>
reduce_over_group(Group g, sycl::vec<T, N> x, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
template <typename Group, typename T, class BinaryOperation>
std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_vector_arithmetic_or_complex<T>::value &&
detail::is_native_op<T, BinaryOperation>::value),
T>
reduce_over_group(Group g, T x, BinaryOperation binary_op) {
static_assert(
std::is_same_v<decltype(binary_op(x[0], x[0])),
typename sycl::vec<T, N>::element_type> ||
(std::is_same_v<sycl::vec<T, N>, half> &&
std::is_same_v<decltype(binary_op(x[0], x[0])), float>),
std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match reduction accumulation type.");
sycl::vec<T, N> result;

detail::loop<N>(
[&](size_t s) { result[s] = reduce_over_group(g, x[s], binary_op); });
T result;
typename detail::get_scalar_binary_op<BinaryOperation>::type
scalar_binary_op{};
detail::loop<x.size()>([&](size_t s) {
result[s] = reduce_over_group(g, x[s], scalar_binary_op);
});
return result;
}

Expand All @@ -272,11 +279,8 @@ std::enable_if_t<
std::is_convertible_v<V, T>),
T>
reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(
std::is_same_v<decltype(binary_op(init, x)), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(init, x)), float>),
std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match reduction accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
return binary_op(init, reduce_over_group(g, T(x), binary_op));
Expand All @@ -295,17 +299,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(
std::is_same_v<decltype(binary_op(init[0], x[0])),
typename T::element_type> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(init[0], x[0])), float>),
std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match reduction accumulation type.");
typename detail::get_scalar_binary_op<BinaryOperation>::type
scalar_binary_op{};
#ifdef __SYCL_DEVICE_ONLY__
T result = init;
for (int s = 0; s < x.size(); ++s) {
result[s] = binary_op(init[s], reduce_over_group(g, x[s], binary_op));
result[s] =
scalar_binary_op(init[s], reduce_over_group(g, x[s], scalar_binary_op));
}
return result;
#else
Expand All @@ -326,11 +329,8 @@ std::enable_if_t<
detail::is_native_op<T, BinaryOperation>::value),
T>
joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(
std::is_same_v<decltype(binary_op(init, *first)), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(init, *first)), float>),
std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match reduction accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
T partial = detail::identity_for_ga_op<T, BinaryOperation>();
Expand Down Expand Up @@ -630,10 +630,7 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(x, x)), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(x, x)), float>),
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
return sycl::detail::calc<__spv::GroupOperation::ExclusiveScan>(
Expand Down Expand Up @@ -674,15 +671,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(x[0], x[0])),
typename T::element_type> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(x[0], x[0])), float>),
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
T result;
typename detail::get_scalar_binary_op<BinaryOperation>::type
scalar_binary_op{};
for (int s = 0; s < x.size(); ++s) {
result[s] = exclusive_scan_over_group(g, x[s], binary_op);
result[s] = exclusive_scan_over_group(g, x[s], scalar_binary_op);
}
return result;
}
Expand All @@ -697,15 +692,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(init[0], x[0])),
typename T::element_type> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(init[0], x[0])), float>),
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
T result;
typename detail::get_scalar_binary_op<BinaryOperation>::type
scalar_binary_op{};
for (int s = 0; s < x.size(); ++s) {
result[s] = exclusive_scan_over_group(g, x[s], init[s], binary_op);
result[s] = exclusive_scan_over_group(g, x[s], init[s], scalar_binary_op);
}
return result;
}
Expand All @@ -720,10 +713,7 @@ std::enable_if_t<
std::is_convertible_v<V, T>),
T>
exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(init, x)), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(init, x)), float>),
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
typename Group::linear_id_type local_linear_id =
Expand Down Expand Up @@ -760,10 +750,7 @@ std::enable_if_t<
OutPtr>
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(init, *first)), float>),
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
Expand Down Expand Up @@ -815,14 +802,9 @@ std::enable_if_t<
OutPtr>
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(
std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type> ||
(std::is_same_v<typename detail::remove_pointer<OutPtr>::type,
half> &&
std::is_same_v<decltype(binary_op(*first, *first)), float>),
"Result type of binary_op must match scan accumulation type.");
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type>,
"Result type of binary_op must match scan accumulation type.");
using T = typename detail::remove_pointer<OutPtr>::type;
T init = detail::identity_for_ga_op<T, BinaryOperation>();
return joint_exclusive_scan(g, first, last, result, init, binary_op);
Expand All @@ -838,15 +820,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(x[0], x[0])),
typename T::element_type> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(x[0], x[0])), float>),
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
T result;
typename detail::get_scalar_binary_op<BinaryOperation>::type
scalar_binary_op{};
for (int s = 0; s < x.size(); ++s) {
result[s] = inclusive_scan_over_group(g, x[s], binary_op);
result[s] = inclusive_scan_over_group(g, x[s], scalar_binary_op);
}
return result;
}
Expand All @@ -859,10 +839,7 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(x, x)), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(x, x)), float>),
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
return sycl::detail::calc<__spv::GroupOperation::InclusiveScan>(
Expand Down Expand Up @@ -908,10 +885,7 @@ std::enable_if_t<
std::is_convertible_v<V, T>),
T>
inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(init, x)), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(init, x)), float>),
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
T y = x;
Expand All @@ -934,14 +908,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
detail::is_native_op<T, BinaryOperation>::value),
T>
inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(init[0], x[0])), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(init[0], x[0])), float>),
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
"Result type of binary_op must match scan accumulation type.");
T result;
typename detail::get_scalar_binary_op<BinaryOperation>::type
scalar_binary_op{};
for (int s = 0; s < x.size(); ++s) {
result[s] = inclusive_scan_over_group(g, x[s], binary_op, init[s]);
result[s] = inclusive_scan_over_group(g, x[s], scalar_binary_op, init[s]);
}
return result;
}
Expand All @@ -962,10 +935,7 @@ std::enable_if_t<
OutPtr>
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
BinaryOperation binary_op, T init) {
// FIXME: Do not special-case for half precision
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T> ||
(std::is_same_v<T, half> &&
std::is_same_v<decltype(binary_op(init, *first)), float>),
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
"Result type of binary_op must match scan accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
Expand Down Expand Up @@ -1014,14 +984,9 @@ std::enable_if_t<
OutPtr>
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
BinaryOperation binary_op) {
// FIXME: Do not special-case for half precision
static_assert(
std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type> ||
(std::is_same_v<typename detail::remove_pointer<OutPtr>::type,
half> &&
std::is_same_v<decltype(binary_op(*first, *first)), float>),
"Result type of binary_op must match scan accumulation type.");
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
typename detail::remove_pointer<OutPtr>::type>,
"Result type of binary_op must match scan accumulation type.");

using T = typename detail::remove_pointer<OutPtr>::type;
T init = detail::identity_for_ga_op<T, BinaryOperation>();
Expand Down
Loading