Skip to content

Commit 78f09e6

Browse files
authored
[SYCL] Allow group algorithms to accept a function object with an explicit type (#9784)
This PR also removes unnecessary special cases for half in the static assertions of the group algorithms (and adds corresponding tests for them) and refactors the `reduce_sycl2020.cpp`, `inclusive_scan_sycl2020.cpp`, and `exclusive_scan_sycl2020.cpp` tests a bit.
1 parent 3f4b778 commit 78f09e6

File tree

5 files changed

+239
-252
lines changed

5 files changed

+239
-252
lines changed

sycl/include/sycl/group_algorithm.hpp

Lines changed: 62 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,19 @@ using is_plus_or_multiplies_if_complex = std::integral_constant<
144144
is_multiplies<T, BinaryOperation>::value)
145145
: std::true_type::value)>;
146146

147+
// used to transform a vector op to a scalar op;
148+
// e.g. sycl::plus<std::vec<T, N>> to sycl::plus<T>
149+
template <typename T> struct get_scalar_binary_op;
150+
151+
template <template <typename> typename F, typename T, int n>
152+
struct get_scalar_binary_op<F<sycl::vec<T, n>>> {
153+
using type = F<T>;
154+
};
155+
156+
template <template <typename> typename F> struct get_scalar_binary_op<F<void>> {
157+
using type = F<void>;
158+
};
159+
147160
// ---- identity_for_ga_op
148161
// the group algorithms support std::complex, limited to sycl::plus operation
149162
// get the correct identity for group algorithm operation.
@@ -201,11 +214,8 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
201214
detail::is_native_op<T, BinaryOperation>::value),
202215
T>
203216
reduce_over_group(Group g, T x, BinaryOperation binary_op) {
204-
// FIXME: Do not special-case for half precision
205217
static_assert(
206-
std::is_same_v<decltype(binary_op(x, x)), T> ||
207-
(std::is_same_v<T, half> &&
208-
std::is_same_v<decltype(binary_op(x, x)), float>),
218+
std::is_same_v<decltype(binary_op(x, x)), T>,
209219
"Result type of binary_op must match reduction accumulation type.");
210220
#ifdef __SYCL_DEVICE_ONLY__
211221
#if defined(__NVPTX__)
@@ -251,24 +261,21 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) {
251261
#endif
252262
}
253263

254-
template <typename Group, typename T, int N, class BinaryOperation>
255-
std::enable_if_t<
256-
(is_group_v<std::decay_t<Group>> &&
257-
detail::is_vector_arithmetic_or_complex<sycl::vec<T, N>>::value &&
258-
detail::is_native_op<sycl::vec<T, N>, BinaryOperation>::value),
259-
sycl::vec<T, N>>
260-
reduce_over_group(Group g, sycl::vec<T, N> x, BinaryOperation binary_op) {
261-
// FIXME: Do not special-case for half precision
264+
template <typename Group, typename T, class BinaryOperation>
265+
std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
266+
detail::is_vector_arithmetic_or_complex<T>::value &&
267+
detail::is_native_op<T, BinaryOperation>::value),
268+
T>
269+
reduce_over_group(Group g, T x, BinaryOperation binary_op) {
262270
static_assert(
263-
std::is_same_v<decltype(binary_op(x[0], x[0])),
264-
typename sycl::vec<T, N>::element_type> ||
265-
(std::is_same_v<sycl::vec<T, N>, half> &&
266-
std::is_same_v<decltype(binary_op(x[0], x[0])), float>),
271+
std::is_same_v<decltype(binary_op(x, x)), T>,
267272
"Result type of binary_op must match reduction accumulation type.");
268-
sycl::vec<T, N> result;
269-
270-
detail::loop<N>(
271-
[&](size_t s) { result[s] = reduce_over_group(g, x[s], binary_op); });
273+
T result;
274+
typename detail::get_scalar_binary_op<BinaryOperation>::type
275+
scalar_binary_op{};
276+
detail::loop<x.size()>([&](size_t s) {
277+
result[s] = reduce_over_group(g, x[s], scalar_binary_op);
278+
});
272279
return result;
273280
}
274281

@@ -284,11 +291,8 @@ std::enable_if_t<
284291
std::is_convertible_v<V, T>),
285292
T>
286293
reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
287-
// FIXME: Do not special-case for half precision
288294
static_assert(
289-
std::is_same_v<decltype(binary_op(init, x)), T> ||
290-
(std::is_same_v<T, half> &&
291-
std::is_same_v<decltype(binary_op(init, x)), float>),
295+
std::is_same_v<decltype(binary_op(init, x)), T>,
292296
"Result type of binary_op must match reduction accumulation type.");
293297
#ifdef __SYCL_DEVICE_ONLY__
294298
return binary_op(init, reduce_over_group(g, T(x), binary_op));
@@ -307,17 +311,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
307311
detail::is_native_op<T, BinaryOperation>::value),
308312
T>
309313
reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
310-
// FIXME: Do not special-case for half precision
311314
static_assert(
312-
std::is_same_v<decltype(binary_op(init[0], x[0])),
313-
typename T::element_type> ||
314-
(std::is_same_v<T, half> &&
315-
std::is_same_v<decltype(binary_op(init[0], x[0])), float>),
315+
std::is_same_v<decltype(binary_op(init, x)), T>,
316316
"Result type of binary_op must match reduction accumulation type.");
317+
typename detail::get_scalar_binary_op<BinaryOperation>::type
318+
scalar_binary_op{};
317319
#ifdef __SYCL_DEVICE_ONLY__
318320
T result = init;
319321
for (int s = 0; s < x.size(); ++s) {
320-
result[s] = binary_op(init[s], reduce_over_group(g, x[s], binary_op));
322+
result[s] =
323+
scalar_binary_op(init[s], reduce_over_group(g, x[s], scalar_binary_op));
321324
}
322325
return result;
323326
#else
@@ -338,11 +341,8 @@ std::enable_if_t<
338341
detail::is_native_op<T, BinaryOperation>::value),
339342
T>
340343
joint_reduce(Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
341-
// FIXME: Do not special-case for half precision
342344
static_assert(
343-
std::is_same_v<decltype(binary_op(init, *first)), T> ||
344-
(std::is_same_v<T, half> &&
345-
std::is_same_v<decltype(binary_op(init, *first)), float>),
345+
std::is_same_v<decltype(binary_op(init, *first)), T>,
346346
"Result type of binary_op must match reduction accumulation type.");
347347
#ifdef __SYCL_DEVICE_ONLY__
348348
T partial = detail::identity_for_ga_op<T, BinaryOperation>();
@@ -667,10 +667,7 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
667667
detail::is_native_op<T, BinaryOperation>::value),
668668
T>
669669
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
670-
// FIXME: Do not special-case for half precision
671-
static_assert(std::is_same_v<decltype(binary_op(x, x)), T> ||
672-
(std::is_same_v<T, half> &&
673-
std::is_same_v<decltype(binary_op(x, x)), float>),
670+
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
674671
"Result type of binary_op must match scan accumulation type.");
675672
#ifdef __SYCL_DEVICE_ONLY__
676673
#if defined(__NVPTX__)
@@ -718,15 +715,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
718715
detail::is_native_op<T, BinaryOperation>::value),
719716
T>
720717
exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
721-
// FIXME: Do not special-case for half precision
722-
static_assert(std::is_same_v<decltype(binary_op(x[0], x[0])),
723-
typename T::element_type> ||
724-
(std::is_same_v<T, half> &&
725-
std::is_same_v<decltype(binary_op(x[0], x[0])), float>),
718+
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
726719
"Result type of binary_op must match scan accumulation type.");
727720
T result;
721+
typename detail::get_scalar_binary_op<BinaryOperation>::type
722+
scalar_binary_op{};
728723
for (int s = 0; s < x.size(); ++s) {
729-
result[s] = exclusive_scan_over_group(g, x[s], binary_op);
724+
result[s] = exclusive_scan_over_group(g, x[s], scalar_binary_op);
730725
}
731726
return result;
732727
}
@@ -741,15 +736,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
741736
detail::is_native_op<T, BinaryOperation>::value),
742737
T>
743738
exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
744-
// FIXME: Do not special-case for half precision
745-
static_assert(std::is_same_v<decltype(binary_op(init[0], x[0])),
746-
typename T::element_type> ||
747-
(std::is_same_v<T, half> &&
748-
std::is_same_v<decltype(binary_op(init[0], x[0])), float>),
739+
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
749740
"Result type of binary_op must match scan accumulation type.");
750741
T result;
742+
typename detail::get_scalar_binary_op<BinaryOperation>::type
743+
scalar_binary_op{};
751744
for (int s = 0; s < x.size(); ++s) {
752-
result[s] = exclusive_scan_over_group(g, x[s], init[s], binary_op);
745+
result[s] = exclusive_scan_over_group(g, x[s], init[s], scalar_binary_op);
753746
}
754747
return result;
755748
}
@@ -764,10 +757,7 @@ std::enable_if_t<
764757
std::is_convertible_v<V, T>),
765758
T>
766759
exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
767-
// FIXME: Do not special-case for half precision
768-
static_assert(std::is_same_v<decltype(binary_op(init, x)), T> ||
769-
(std::is_same_v<T, half> &&
770-
std::is_same_v<decltype(binary_op(init, x)), float>),
760+
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
771761
"Result type of binary_op must match scan accumulation type.");
772762
#ifdef __SYCL_DEVICE_ONLY__
773763
typename Group::linear_id_type local_linear_id =
@@ -804,10 +794,7 @@ std::enable_if_t<
804794
OutPtr>
805795
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result, T init,
806796
BinaryOperation binary_op) {
807-
// FIXME: Do not special-case for half precision
808-
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T> ||
809-
(std::is_same_v<T, half> &&
810-
std::is_same_v<decltype(binary_op(init, *first)), float>),
797+
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
811798
"Result type of binary_op must match scan accumulation type.");
812799
#ifdef __SYCL_DEVICE_ONLY__
813800
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
@@ -859,14 +846,9 @@ std::enable_if_t<
859846
OutPtr>
860847
joint_exclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
861848
BinaryOperation binary_op) {
862-
// FIXME: Do not special-case for half precision
863-
static_assert(
864-
std::is_same_v<decltype(binary_op(*first, *first)),
865-
typename detail::remove_pointer<OutPtr>::type> ||
866-
(std::is_same_v<typename detail::remove_pointer<OutPtr>::type,
867-
half> &&
868-
std::is_same_v<decltype(binary_op(*first, *first)), float>),
869-
"Result type of binary_op must match scan accumulation type.");
849+
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
850+
typename detail::remove_pointer<OutPtr>::type>,
851+
"Result type of binary_op must match scan accumulation type.");
870852
using T = typename detail::remove_pointer<OutPtr>::type;
871853
T init = detail::identity_for_ga_op<T, BinaryOperation>();
872854
return joint_exclusive_scan(g, first, last, result, init, binary_op);
@@ -882,15 +864,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
882864
detail::is_native_op<T, BinaryOperation>::value),
883865
T>
884866
inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
885-
// FIXME: Do not special-case for half precision
886-
static_assert(std::is_same_v<decltype(binary_op(x[0], x[0])),
887-
typename T::element_type> ||
888-
(std::is_same_v<T, half> &&
889-
std::is_same_v<decltype(binary_op(x[0], x[0])), float>),
867+
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
890868
"Result type of binary_op must match scan accumulation type.");
891869
T result;
870+
typename detail::get_scalar_binary_op<BinaryOperation>::type
871+
scalar_binary_op{};
892872
for (int s = 0; s < x.size(); ++s) {
893-
result[s] = inclusive_scan_over_group(g, x[s], binary_op);
873+
result[s] = inclusive_scan_over_group(g, x[s], scalar_binary_op);
894874
}
895875
return result;
896876
}
@@ -903,10 +883,7 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
903883
detail::is_native_op<T, BinaryOperation>::value),
904884
T>
905885
inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
906-
// FIXME: Do not special-case for half precision
907-
static_assert(std::is_same_v<decltype(binary_op(x, x)), T> ||
908-
(std::is_same_v<T, half> &&
909-
std::is_same_v<decltype(binary_op(x, x)), float>),
886+
static_assert(std::is_same_v<decltype(binary_op(x, x)), T>,
910887
"Result type of binary_op must match scan accumulation type.");
911888
#ifdef __SYCL_DEVICE_ONLY__
912889
#if defined(__NVPTX__)
@@ -959,10 +936,7 @@ std::enable_if_t<
959936
std::is_convertible_v<V, T>),
960937
T>
961938
inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
962-
// FIXME: Do not special-case for half precision
963-
static_assert(std::is_same_v<decltype(binary_op(init, x)), T> ||
964-
(std::is_same_v<T, half> &&
965-
std::is_same_v<decltype(binary_op(init, x)), float>),
939+
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
966940
"Result type of binary_op must match scan accumulation type.");
967941
#ifdef __SYCL_DEVICE_ONLY__
968942
T y = x;
@@ -985,14 +959,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
985959
detail::is_native_op<T, BinaryOperation>::value),
986960
T>
987961
inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
988-
// FIXME: Do not special-case for half precision
989-
static_assert(std::is_same_v<decltype(binary_op(init[0], x[0])), T> ||
990-
(std::is_same_v<T, half> &&
991-
std::is_same_v<decltype(binary_op(init[0], x[0])), float>),
962+
static_assert(std::is_same_v<decltype(binary_op(init, x)), T>,
992963
"Result type of binary_op must match scan accumulation type.");
993964
T result;
965+
typename detail::get_scalar_binary_op<BinaryOperation>::type
966+
scalar_binary_op{};
994967
for (int s = 0; s < x.size(); ++s) {
995-
result[s] = inclusive_scan_over_group(g, x[s], binary_op, init[s]);
968+
result[s] = inclusive_scan_over_group(g, x[s], scalar_binary_op, init[s]);
996969
}
997970
return result;
998971
}
@@ -1013,10 +986,7 @@ std::enable_if_t<
1013986
OutPtr>
1014987
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
1015988
BinaryOperation binary_op, T init) {
1016-
// FIXME: Do not special-case for half precision
1017-
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T> ||
1018-
(std::is_same_v<T, half> &&
1019-
std::is_same_v<decltype(binary_op(init, *first)), float>),
989+
static_assert(std::is_same_v<decltype(binary_op(init, *first)), T>,
1020990
"Result type of binary_op must match scan accumulation type.");
1021991
#ifdef __SYCL_DEVICE_ONLY__
1022992
ptrdiff_t offset = sycl::detail::get_local_linear_id(g);
@@ -1065,14 +1035,9 @@ std::enable_if_t<
10651035
OutPtr>
10661036
joint_inclusive_scan(Group g, InPtr first, InPtr last, OutPtr result,
10671037
BinaryOperation binary_op) {
1068-
// FIXME: Do not special-case for half precision
1069-
static_assert(
1070-
std::is_same_v<decltype(binary_op(*first, *first)),
1071-
typename detail::remove_pointer<OutPtr>::type> ||
1072-
(std::is_same_v<typename detail::remove_pointer<OutPtr>::type,
1073-
half> &&
1074-
std::is_same_v<decltype(binary_op(*first, *first)), float>),
1075-
"Result type of binary_op must match scan accumulation type.");
1038+
static_assert(std::is_same_v<decltype(binary_op(*first, *first)),
1039+
typename detail::remove_pointer<OutPtr>::type>,
1040+
"Result type of binary_op must match scan accumulation type.");
10761041

10771042
using T = typename detail::remove_pointer<OutPtr>::type;
10781043
T init = detail::identity_for_ga_op<T, BinaryOperation>();

0 commit comments

Comments
 (0)