@@ -50,6 +50,14 @@ namespace tensor
50
50
namespace kernels
51
51
{
52
52
53
+ template <typename ReductionOpT, typename T> struct can_use_reduce_over_group
54
+ {
55
+ static constexpr bool value =
56
+ sycl::has_known_identity<ReductionOpT, T>::value &&
57
+ !std::is_same_v<T, std::int64_t > && !std::is_same_v<T, std::uint64_t > &&
58
+ !std::is_same_v<ReductionOpT, sycl::multiplies<T>>;
59
+ };
60
+
53
61
template <typename argT,
54
62
typename outT,
55
63
typename ReductionOp,
@@ -477,7 +485,8 @@ sycl::event reduction_over_group_with_atomics_strided_impl(
477
485
sycl::range<1 >{iter_nelems * reduction_groups * wg};
478
486
auto localRange = sycl::range<1 >{wg};
479
487
480
- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
488
+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
489
+ {
481
490
using KernelName = class reduction_over_group_with_atomics_krn <
482
491
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
483
492
ReductionIndexerT>;
@@ -618,7 +627,8 @@ sycl::event reduction_axis1_over_group_with_atomics_contig_impl(
618
627
sycl::range<1 >{iter_nelems * reduction_groups * wg};
619
628
auto localRange = sycl::range<1 >{wg};
620
629
621
- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
630
+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
631
+ {
622
632
using KernelName =
623
633
class reduction_axis1_over_group_with_atomics_contig_krn <
624
634
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
@@ -717,7 +727,8 @@ sycl::event reduction_axis0_over_group_with_atomics_contig_impl(
717
727
sycl::range<1 >{iter_nelems * reduction_groups * wg};
718
728
auto localRange = sycl::range<1 >{wg};
719
729
720
- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
730
+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
731
+ {
721
732
using KernelName =
722
733
class reduction_axis0_over_group_with_atomics_contig_krn <
723
734
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
@@ -1007,10 +1018,12 @@ sycl::event reduction_over_group_temps_strided_impl(
1007
1018
sycl::range<1 >{iter_nelems * reduction_groups * wg};
1008
1019
auto localRange = sycl::range<1 >{wg};
1009
1020
1010
- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1021
+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
1022
+ {
1011
1023
using KernelName = class reduction_over_group_temps_krn <
1012
1024
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1013
1025
ReductionIndexerT>;
1026
+
1014
1027
cgh.parallel_for <KernelName>(
1015
1028
sycl::nd_range<1 >(globalRange, localRange),
1016
1029
ReductionOverGroupNoAtomicFunctor<
@@ -1026,6 +1039,7 @@ sycl::event reduction_over_group_temps_strided_impl(
1026
1039
using KernelName = class custom_reduction_over_group_temps_krn <
1027
1040
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1028
1041
ReductionIndexerT, SlmT>;
1042
+
1029
1043
cgh.parallel_for <KernelName>(
1030
1044
sycl::nd_range<1 >(globalRange, localRange),
1031
1045
CustomReductionOverGroupNoAtomicFunctor<
@@ -1062,68 +1076,67 @@ sycl::event reduction_over_group_temps_strided_impl(
1062
1076
partially_reduced_tmp + reduction_groups * iter_nelems;
1063
1077
}
1064
1078
1065
- const sycl::event &first_reduction_ev =
1066
- exec_q. submit ([&](sycl::handler &cgh) {
1067
- cgh.depends_on (depends);
1079
+ const sycl::event &first_reduction_ev = exec_q. submit ([&](sycl::handler
1080
+ &cgh) {
1081
+ cgh.depends_on (depends);
1068
1082
1069
- using InputIndexerT =
1070
- dpctl::tensor::offset_utils::StridedIndexer;
1071
- using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
1072
- using InputOutputIterIndexerT =
1073
- dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
1074
- InputIndexerT, ResIndexerT>;
1075
- using ReductionIndexerT =
1076
- dpctl::tensor::offset_utils::StridedIndexer;
1083
+ using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
1084
+ using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
1085
+ using InputOutputIterIndexerT =
1086
+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
1087
+ InputIndexerT, ResIndexerT>;
1088
+ using ReductionIndexerT =
1089
+ dpctl::tensor::offset_utils::StridedIndexer;
1077
1090
1078
- // Only 2*iter_nd entries describing shape and strides of
1079
- // iterated dimensions of input array from
1080
- // iter_shape_and_strides are going to be accessed by
1081
- // inp_indexer
1082
- InputIndexerT inp_indexer (iter_nd, iter_arg_offset,
1083
- iter_shape_and_strides);
1084
- ResIndexerT noop_tmp_indexer{};
1091
+ // Only 2*iter_nd entries describing shape and strides of
1092
+ // iterated dimensions of input array from
1093
+ // iter_shape_and_strides are going to be accessed by
1094
+ // inp_indexer
1095
+ InputIndexerT inp_indexer (iter_nd, iter_arg_offset,
1096
+ iter_shape_and_strides);
1097
+ ResIndexerT noop_tmp_indexer{};
1085
1098
1086
- InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
1087
- noop_tmp_indexer};
1088
- ReductionIndexerT reduction_indexer{
1089
- red_nd, reduction_arg_offset, reduction_shape_stride};
1099
+ InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
1100
+ noop_tmp_indexer};
1101
+ ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
1102
+ reduction_shape_stride};
1090
1103
1091
- auto globalRange =
1092
- sycl::range<1 >{iter_nelems * reduction_groups * wg};
1093
- auto localRange = sycl::range<1 >{wg};
1104
+ auto globalRange =
1105
+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
1106
+ auto localRange = sycl::range<1 >{wg};
1094
1107
1095
- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1096
- using KernelName = class reduction_over_group_temps_krn <
1108
+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
1109
+ {
1110
+ using KernelName = class reduction_over_group_temps_krn <
1111
+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1112
+ ReductionIndexerT>;
1113
+ cgh.parallel_for <KernelName>(
1114
+ sycl::nd_range<1 >(globalRange, localRange),
1115
+ ReductionOverGroupNoAtomicFunctor<
1097
1116
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1098
- ReductionIndexerT>;
1099
- cgh.parallel_for <KernelName>(
1100
- sycl::nd_range<1 >(globalRange, localRange),
1101
- ReductionOverGroupNoAtomicFunctor<
1102
- argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1103
- ReductionIndexerT>(
1104
- arg_tp, partially_reduced_tmp, ReductionOpT (),
1105
- identity_val, in_out_iter_indexer,
1106
- reduction_indexer, reduction_nelems, iter_nelems,
1107
- preferrered_reductions_per_wi));
1108
- }
1109
- else {
1110
- using SlmT = sycl::local_accessor<resTy, 1 >;
1111
- SlmT local_memory = SlmT (localRange, cgh);
1112
- using KernelName =
1113
- class custom_reduction_over_group_temps_krn <
1114
- argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1115
- ReductionIndexerT, SlmT>;
1116
- cgh.parallel_for <KernelName>(
1117
- sycl::nd_range<1 >(globalRange, localRange),
1118
- CustomReductionOverGroupNoAtomicFunctor<
1119
- argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1120
- ReductionIndexerT, SlmT>(
1121
- arg_tp, partially_reduced_tmp, ReductionOpT (),
1122
- identity_val, in_out_iter_indexer,
1123
- reduction_indexer, local_memory, reduction_nelems,
1124
- iter_nelems, preferrered_reductions_per_wi));
1125
- }
1126
- });
1117
+ ReductionIndexerT>(
1118
+ arg_tp, partially_reduced_tmp, ReductionOpT (),
1119
+ identity_val, in_out_iter_indexer, reduction_indexer,
1120
+ reduction_nelems, iter_nelems,
1121
+ preferrered_reductions_per_wi));
1122
+ }
1123
+ else {
1124
+ using SlmT = sycl::local_accessor<resTy, 1 >;
1125
+ SlmT local_memory = SlmT (localRange, cgh);
1126
+ using KernelName = class custom_reduction_over_group_temps_krn <
1127
+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1128
+ ReductionIndexerT, SlmT>;
1129
+ cgh.parallel_for <KernelName>(
1130
+ sycl::nd_range<1 >(globalRange, localRange),
1131
+ CustomReductionOverGroupNoAtomicFunctor<
1132
+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1133
+ ReductionIndexerT, SlmT>(
1134
+ arg_tp, partially_reduced_tmp, ReductionOpT (),
1135
+ identity_val, in_out_iter_indexer, reduction_indexer,
1136
+ local_memory, reduction_nelems, iter_nelems,
1137
+ preferrered_reductions_per_wi));
1138
+ }
1139
+ });
1127
1140
1128
1141
size_t remaining_reduction_nelems = reduction_groups;
1129
1142
@@ -1165,7 +1178,8 @@ sycl::event reduction_over_group_temps_strided_impl(
1165
1178
auto globalRange =
1166
1179
sycl::range<1 >{iter_nelems * reduction_groups_ * wg};
1167
1180
auto localRange = sycl::range<1 >{wg};
1168
- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1181
+ if constexpr (can_use_reduce_over_group<ReductionOpT,
1182
+ resTy>::value) {
1169
1183
using KernelName = class reduction_over_group_temps_krn <
1170
1184
resTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1171
1185
ReductionIndexerT>;
@@ -1240,7 +1254,8 @@ sycl::event reduction_over_group_temps_strided_impl(
1240
1254
sycl::range<1 >{iter_nelems * reduction_groups * wg};
1241
1255
auto localRange = sycl::range<1 >{wg};
1242
1256
1243
- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1257
+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
1258
+ {
1244
1259
using KernelName = class reduction_over_group_temps_krn <
1245
1260
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1246
1261
ReductionIndexerT>;
@@ -2564,7 +2579,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2564
2579
sycl::range<1 >{iter_nelems * reduction_groups * wg};
2565
2580
auto localRange = sycl::range<1 >{wg};
2566
2581
2567
- if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2582
+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
2583
+ {
2568
2584
using KernelName = class search_reduction_over_group_temps_krn <
2569
2585
argTy, resTy, ReductionOpT, IndexOpT,
2570
2586
InputOutputIterIndexerT, ReductionIndexerT, true , true >;
@@ -2663,7 +2679,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2663
2679
sycl::range<1 >{iter_nelems * reduction_groups * wg};
2664
2680
auto localRange = sycl::range<1 >{wg};
2665
2681
2666
- if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2682
+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
2683
+ {
2667
2684
using KernelName = class search_reduction_over_group_temps_krn <
2668
2685
argTy, resTy, ReductionOpT, IndexOpT,
2669
2686
InputOutputIterIndexerT, ReductionIndexerT, true , false >;
@@ -2743,7 +2760,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2743
2760
auto globalRange =
2744
2761
sycl::range<1 >{iter_nelems * reduction_groups_ * wg};
2745
2762
auto localRange = sycl::range<1 >{wg};
2746
- if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2763
+ if constexpr (can_use_reduce_over_group<ReductionOpT,
2764
+ resTy>::value) {
2747
2765
using KernelName =
2748
2766
class search_reduction_over_group_temps_krn <
2749
2767
argTy, resTy, ReductionOpT, IndexOpT,
@@ -2826,7 +2844,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
2826
2844
sycl::range<1 >{iter_nelems * reduction_groups * wg};
2827
2845
auto localRange = sycl::range<1 >{wg};
2828
2846
2829
- if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2847
+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
2848
+ {
2830
2849
using KernelName = class search_reduction_over_group_temps_krn <
2831
2850
argTy, resTy, ReductionOpT, IndexOpT,
2832
2851
InputOutputIterIndexerT, ReductionIndexerT, false , true >;
0 commit comments