@@ -144,6 +144,19 @@ using is_plus_or_multiplies_if_complex = std::integral_constant<
144
144
is_multiplies<T, BinaryOperation>::value)
145
145
: std::true_type::value)>;
146
146
147
+ // used to transform a vector op to a scalar op;
148
+ // e.g. sycl::plus<sycl::vec<T, N>> to sycl::plus<T>
149
+ template <typename T> struct get_scalar_binary_op ; // primary template is declared but not defined here; only the specializations below supply ::type
150
+ // vector case: F<sycl::vec<T, n>> is mapped to F<T>
151
+ template <template <typename > typename F, typename T, int n>
152
+ struct get_scalar_binary_op <F<sycl::vec<T, n>>> {
153
+ using type = F<T>;
154
+ };
155
+ // F<void> case: the operation type is left unchanged
156
+ template <template <typename > typename F> struct get_scalar_binary_op <F<void >> {
157
+ using type = F<void >;
158
+ };
159
+
147
160
// ---- identity_for_ga_op
148
161
// the group algorithms support std::complex, limited to sycl::plus operation
149
162
// get the correct identity for group algorithm operation.
@@ -201,11 +214,8 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
201
214
detail::is_native_op<T, BinaryOperation>::value),
202
215
T>
203
216
reduce_over_group (Group g, T x, BinaryOperation binary_op) {
204
- // FIXME: Do not special-case for half precision
205
217
static_assert (
206
- std::is_same_v<decltype (binary_op (x, x)), T> ||
207
- (std::is_same_v<T, half> &&
208
- std::is_same_v<decltype (binary_op (x, x)), float >),
218
+ std::is_same_v<decltype (binary_op (x, x)), T>,
209
219
" Result type of binary_op must match reduction accumulation type." );
210
220
#ifdef __SYCL_DEVICE_ONLY__
211
221
#if defined(__NVPTX__)
@@ -251,24 +261,21 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) {
251
261
#endif
252
262
}
253
263
254
- template <typename Group, typename T, int N, class BinaryOperation >
255
- std::enable_if_t <
256
- (is_group_v<std::decay_t <Group>> &&
257
- detail::is_vector_arithmetic_or_complex<sycl::vec<T, N>>::value &&
258
- detail::is_native_op<sycl::vec<T, N>, BinaryOperation>::value),
259
- sycl::vec<T, N>>
260
- reduce_over_group (Group g, sycl::vec<T, N> x, BinaryOperation binary_op) {
261
- // FIXME: Do not special-case for half precision
264
+ template <typename Group, typename T, class BinaryOperation >
265
+ std::enable_if_t <(is_group_v<std::decay_t <Group>> &&
266
+ detail::is_vector_arithmetic_or_complex<T>::value &&
267
+ detail::is_native_op<T, BinaryOperation>::value),
268
+ T>
269
+ reduce_over_group (Group g, T x, BinaryOperation binary_op) { // reduces each x[s] separately and stores the outcome in result[s]
262
270
static_assert (
263
- std::is_same_v<decltype (binary_op (x[0 ], x[0 ])),
264
- typename sycl::vec<T, N>::element_type> ||
265
- (std::is_same_v<sycl::vec<T, N>, half> &&
266
- std::is_same_v<decltype (binary_op (x[0 ], x[0 ])), float >),
271
+ std::is_same_v<decltype (binary_op (x, x)), T>, // binary_op applied to two T values must return exactly T
267
272
" Result type of binary_op must match reduction accumulation type." );
268
- sycl::vec<T, N> result;
269
-
270
- detail::loop<N>(
271
- [&](size_t s) { result[s] = reduce_over_group (g, x[s], binary_op); });
273
+ T result;
274
+ typename detail::get_scalar_binary_op<BinaryOperation>::type
275
+ scalar_binary_op{}; // value-initialized op of the mapped type; binary_op itself is only used in the static_assert above
276
+ detail::loop<x.size ()>([&](size_t s) { // presumably invokes the lambda once per index s in [0, x.size()) - confirm against detail::loop
277
+ result[s] = reduce_over_group (g, x[s], scalar_binary_op);
278
+ });
272
279
return result;
273
280
}
274
281
@@ -284,11 +291,8 @@ std::enable_if_t<
284
291
std::is_convertible_v<V, T>),
285
292
T>
286
293
reduce_over_group (Group g, V x, T init, BinaryOperation binary_op) {
287
- // FIXME: Do not special-case for half precision
288
294
static_assert (
289
- std::is_same_v<decltype (binary_op (init, x)), T> ||
290
- (std::is_same_v<T, half> &&
291
- std::is_same_v<decltype (binary_op (init, x)), float >),
295
+ std::is_same_v<decltype (binary_op (init, x)), T>,
292
296
" Result type of binary_op must match reduction accumulation type." );
293
297
#ifdef __SYCL_DEVICE_ONLY__
294
298
return binary_op (init, reduce_over_group (g, T (x), binary_op));
@@ -307,17 +311,16 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
307
311
detail::is_native_op<T, BinaryOperation>::value),
308
312
T>
309
313
reduce_over_group (Group g, V x, T init, BinaryOperation binary_op) {
310
- // FIXME: Do not special-case for half precision
311
314
static_assert (
312
- std::is_same_v<decltype (binary_op (init[0 ], x[0 ])),
313
- typename T::element_type> ||
314
- (std::is_same_v<T, half> &&
315
- std::is_same_v<decltype (binary_op (init[0 ], x[0 ])), float >),
315
+ std::is_same_v<decltype (binary_op (init, x)), T>,
316
316
" Result type of binary_op must match reduction accumulation type." );
317
+ typename detail::get_scalar_binary_op<BinaryOperation>::type
318
+ scalar_binary_op{};
317
319
#ifdef __SYCL_DEVICE_ONLY__
318
320
T result = init;
319
321
for (int s = 0 ; s < x.size (); ++s) {
320
- result[s] = binary_op (init[s], reduce_over_group (g, x[s], binary_op));
322
+ result[s] =
323
+ scalar_binary_op (init[s], reduce_over_group (g, x[s], scalar_binary_op));
321
324
}
322
325
return result;
323
326
#else
@@ -338,11 +341,8 @@ std::enable_if_t<
338
341
detail::is_native_op<T, BinaryOperation>::value),
339
342
T>
340
343
joint_reduce (Group g, Ptr first, Ptr last, T init, BinaryOperation binary_op) {
341
- // FIXME: Do not special-case for half precision
342
344
static_assert (
343
- std::is_same_v<decltype (binary_op (init, *first)), T> ||
344
- (std::is_same_v<T, half> &&
345
- std::is_same_v<decltype (binary_op (init, *first)), float >),
345
+ std::is_same_v<decltype (binary_op (init, *first)), T>,
346
346
" Result type of binary_op must match reduction accumulation type." );
347
347
#ifdef __SYCL_DEVICE_ONLY__
348
348
T partial = detail::identity_for_ga_op<T, BinaryOperation>();
@@ -667,10 +667,7 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
667
667
detail::is_native_op<T, BinaryOperation>::value),
668
668
T>
669
669
exclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
670
- // FIXME: Do not special-case for half precision
671
- static_assert (std::is_same_v<decltype (binary_op (x, x)), T> ||
672
- (std::is_same_v<T, half> &&
673
- std::is_same_v<decltype (binary_op (x, x)), float >),
670
+ static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
674
671
" Result type of binary_op must match scan accumulation type." );
675
672
#ifdef __SYCL_DEVICE_ONLY__
676
673
#if defined(__NVPTX__)
@@ -718,15 +715,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
718
715
detail::is_native_op<T, BinaryOperation>::value),
719
716
T>
720
717
exclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
721
- // FIXME: Do not special-case for half precision
722
- static_assert (std::is_same_v<decltype (binary_op (x[0 ], x[0 ])),
723
- typename T::element_type> ||
724
- (std::is_same_v<T, half> &&
725
- std::is_same_v<decltype (binary_op (x[0 ], x[0 ])), float >),
718
+ static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
726
719
" Result type of binary_op must match scan accumulation type." );
727
720
T result;
721
+ typename detail::get_scalar_binary_op<BinaryOperation>::type
722
+ scalar_binary_op{};
728
723
for (int s = 0 ; s < x.size (); ++s) {
729
- result[s] = exclusive_scan_over_group (g, x[s], binary_op );
724
+ result[s] = exclusive_scan_over_group (g, x[s], scalar_binary_op );
730
725
}
731
726
return result;
732
727
}
@@ -741,15 +736,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
741
736
detail::is_native_op<T, BinaryOperation>::value),
742
737
T>
743
738
exclusive_scan_over_group (Group g, V x, T init, BinaryOperation binary_op) {
744
- // FIXME: Do not special-case for half precision
745
- static_assert (std::is_same_v<decltype (binary_op (init[0 ], x[0 ])),
746
- typename T::element_type> ||
747
- (std::is_same_v<T, half> &&
748
- std::is_same_v<decltype (binary_op (init[0 ], x[0 ])), float >),
739
+ static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
749
740
" Result type of binary_op must match scan accumulation type." );
750
741
T result;
742
+ typename detail::get_scalar_binary_op<BinaryOperation>::type
743
+ scalar_binary_op{};
751
744
for (int s = 0 ; s < x.size (); ++s) {
752
- result[s] = exclusive_scan_over_group (g, x[s], init[s], binary_op );
745
+ result[s] = exclusive_scan_over_group (g, x[s], init[s], scalar_binary_op );
753
746
}
754
747
return result;
755
748
}
@@ -764,10 +757,7 @@ std::enable_if_t<
764
757
std::is_convertible_v<V, T>),
765
758
T>
766
759
exclusive_scan_over_group (Group g, V x, T init, BinaryOperation binary_op) {
767
- // FIXME: Do not special-case for half precision
768
- static_assert (std::is_same_v<decltype (binary_op (init, x)), T> ||
769
- (std::is_same_v<T, half> &&
770
- std::is_same_v<decltype (binary_op (init, x)), float >),
760
+ static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
771
761
" Result type of binary_op must match scan accumulation type." );
772
762
#ifdef __SYCL_DEVICE_ONLY__
773
763
typename Group::linear_id_type local_linear_id =
@@ -804,10 +794,7 @@ std::enable_if_t<
804
794
OutPtr>
805
795
joint_exclusive_scan (Group g, InPtr first, InPtr last, OutPtr result, T init,
806
796
BinaryOperation binary_op) {
807
- // FIXME: Do not special-case for half precision
808
- static_assert (std::is_same_v<decltype (binary_op (init, *first)), T> ||
809
- (std::is_same_v<T, half> &&
810
- std::is_same_v<decltype (binary_op (init, *first)), float >),
797
+ static_assert (std::is_same_v<decltype (binary_op (init, *first)), T>,
811
798
" Result type of binary_op must match scan accumulation type." );
812
799
#ifdef __SYCL_DEVICE_ONLY__
813
800
ptrdiff_t offset = sycl::detail::get_local_linear_id (g);
@@ -859,14 +846,9 @@ std::enable_if_t<
859
846
OutPtr>
860
847
joint_exclusive_scan (Group g, InPtr first, InPtr last, OutPtr result,
861
848
BinaryOperation binary_op) {
862
- // FIXME: Do not special-case for half precision
863
- static_assert (
864
- std::is_same_v<decltype (binary_op (*first, *first)),
865
- typename detail::remove_pointer<OutPtr>::type> ||
866
- (std::is_same_v<typename detail::remove_pointer<OutPtr>::type,
867
- half> &&
868
- std::is_same_v<decltype (binary_op (*first, *first)), float >),
869
- " Result type of binary_op must match scan accumulation type." );
849
+ static_assert (std::is_same_v<decltype (binary_op (*first, *first)),
850
+ typename detail::remove_pointer<OutPtr>::type>,
851
+ " Result type of binary_op must match scan accumulation type." );
870
852
using T = typename detail::remove_pointer<OutPtr>::type;
871
853
T init = detail::identity_for_ga_op<T, BinaryOperation>();
872
854
return joint_exclusive_scan (g, first, last, result, init, binary_op);
@@ -882,15 +864,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
882
864
detail::is_native_op<T, BinaryOperation>::value),
883
865
T>
884
866
inclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
885
- // FIXME: Do not special-case for half precision
886
- static_assert (std::is_same_v<decltype (binary_op (x[0 ], x[0 ])),
887
- typename T::element_type> ||
888
- (std::is_same_v<T, half> &&
889
- std::is_same_v<decltype (binary_op (x[0 ], x[0 ])), float >),
867
+ static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
890
868
" Result type of binary_op must match scan accumulation type." );
891
869
T result;
870
+ typename detail::get_scalar_binary_op<BinaryOperation>::type
871
+ scalar_binary_op{};
892
872
for (int s = 0 ; s < x.size (); ++s) {
893
- result[s] = inclusive_scan_over_group (g, x[s], binary_op );
873
+ result[s] = inclusive_scan_over_group (g, x[s], scalar_binary_op );
894
874
}
895
875
return result;
896
876
}
@@ -903,10 +883,7 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
903
883
detail::is_native_op<T, BinaryOperation>::value),
904
884
T>
905
885
inclusive_scan_over_group (Group g, T x, BinaryOperation binary_op) {
906
- // FIXME: Do not special-case for half precision
907
- static_assert (std::is_same_v<decltype (binary_op (x, x)), T> ||
908
- (std::is_same_v<T, half> &&
909
- std::is_same_v<decltype (binary_op (x, x)), float >),
886
+ static_assert (std::is_same_v<decltype (binary_op (x, x)), T>,
910
887
" Result type of binary_op must match scan accumulation type." );
911
888
#ifdef __SYCL_DEVICE_ONLY__
912
889
#if defined(__NVPTX__)
@@ -959,10 +936,7 @@ std::enable_if_t<
959
936
std::is_convertible_v<V, T>),
960
937
T>
961
938
inclusive_scan_over_group (Group g, V x, BinaryOperation binary_op, T init) {
962
- // FIXME: Do not special-case for half precision
963
- static_assert (std::is_same_v<decltype (binary_op (init, x)), T> ||
964
- (std::is_same_v<T, half> &&
965
- std::is_same_v<decltype (binary_op (init, x)), float >),
939
+ static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
966
940
" Result type of binary_op must match scan accumulation type." );
967
941
#ifdef __SYCL_DEVICE_ONLY__
968
942
T y = x;
@@ -985,14 +959,13 @@ std::enable_if_t<(is_group_v<std::decay_t<Group>> &&
985
959
detail::is_native_op<T, BinaryOperation>::value),
986
960
T>
987
961
inclusive_scan_over_group (Group g, V x, BinaryOperation binary_op, T init) {
988
- // FIXME: Do not special-case for half precision
989
- static_assert (std::is_same_v<decltype (binary_op (init[0 ], x[0 ])), T> ||
990
- (std::is_same_v<T, half> &&
991
- std::is_same_v<decltype (binary_op (init[0 ], x[0 ])), float >),
962
+ static_assert (std::is_same_v<decltype (binary_op (init, x)), T>,
992
963
" Result type of binary_op must match scan accumulation type." );
993
964
T result;
965
+ typename detail::get_scalar_binary_op<BinaryOperation>::type
966
+ scalar_binary_op{};
994
967
for (int s = 0 ; s < x.size (); ++s) {
995
- result[s] = inclusive_scan_over_group (g, x[s], binary_op , init[s]);
968
+ result[s] = inclusive_scan_over_group (g, x[s], scalar_binary_op , init[s]);
996
969
}
997
970
return result;
998
971
}
@@ -1013,10 +986,7 @@ std::enable_if_t<
1013
986
OutPtr>
1014
987
joint_inclusive_scan (Group g, InPtr first, InPtr last, OutPtr result,
1015
988
BinaryOperation binary_op, T init) {
1016
- // FIXME: Do not special-case for half precision
1017
- static_assert (std::is_same_v<decltype (binary_op (init, *first)), T> ||
1018
- (std::is_same_v<T, half> &&
1019
- std::is_same_v<decltype (binary_op (init, *first)), float >),
989
+ static_assert (std::is_same_v<decltype (binary_op (init, *first)), T>,
1020
990
" Result type of binary_op must match scan accumulation type." );
1021
991
#ifdef __SYCL_DEVICE_ONLY__
1022
992
ptrdiff_t offset = sycl::detail::get_local_linear_id (g);
@@ -1065,14 +1035,9 @@ std::enable_if_t<
1065
1035
OutPtr>
1066
1036
joint_inclusive_scan (Group g, InPtr first, InPtr last, OutPtr result,
1067
1037
BinaryOperation binary_op) {
1068
- // FIXME: Do not special-case for half precision
1069
- static_assert (
1070
- std::is_same_v<decltype (binary_op (*first, *first)),
1071
- typename detail::remove_pointer<OutPtr>::type> ||
1072
- (std::is_same_v<typename detail::remove_pointer<OutPtr>::type,
1073
- half> &&
1074
- std::is_same_v<decltype (binary_op (*first, *first)), float >),
1075
- " Result type of binary_op must match scan accumulation type." );
1038
+ static_assert (std::is_same_v<decltype (binary_op (*first, *first)),
1039
+ typename detail::remove_pointer<OutPtr>::type>,
1040
+ " Result type of binary_op must match scan accumulation type." );
1076
1041
1077
1042
using T = typename detail::remove_pointer<OutPtr>::type;
1078
1043
T init = detail::identity_for_ga_op<T, BinaryOperation>();
0 commit comments