Skip to content

Commit ee46ae1

Browse files
Common reduction template takes functions to test if atomics are applicable
Passing these function pointers around allows to turn atomic off altogether if desired. Use custom trait to check if reduce_over_groups can be used. This allows to work-around bug, or switch to custom code for reduction over group if desired. Such custom trait type works around issue with incorrect result returned from sycl::reduce_over_group for sycl::multiplies operator for 64-bit integral types.
1 parent ca0ff64 commit ee46ae1

File tree

3 files changed

+152
-98
lines changed

3 files changed

+152
-98
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 85 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ namespace tensor
5050
namespace kernels
5151
{
5252

53+
template <typename ReductionOpT, typename T> struct can_use_reduce_over_group
54+
{
55+
static constexpr bool value =
56+
sycl::has_known_identity<ReductionOpT, T>::value &&
57+
!std::is_same_v<T, std::int64_t> && !std::is_same_v<T, std::uint64_t> &&
58+
!std::is_same_v<ReductionOpT, sycl::multiplies<T>>;
59+
};
60+
5361
template <typename argT,
5462
typename outT,
5563
typename ReductionOp,
@@ -477,7 +485,8 @@ sycl::event reduction_over_group_with_atomics_strided_impl(
477485
sycl::range<1>{iter_nelems * reduction_groups * wg};
478486
auto localRange = sycl::range<1>{wg};
479487

480-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
488+
if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
489+
{
481490
using KernelName = class reduction_over_group_with_atomics_krn<
482491
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
483492
ReductionIndexerT>;
@@ -618,7 +627,8 @@ sycl::event reduction_axis1_over_group_with_atomics_contig_impl(
618627
sycl::range<1>{iter_nelems * reduction_groups * wg};
619628
auto localRange = sycl::range<1>{wg};
620629

621-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
630+
if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
631+
{
622632
using KernelName =
623633
class reduction_axis1_over_group_with_atomics_contig_krn<
624634
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
@@ -717,7 +727,8 @@ sycl::event reduction_axis0_over_group_with_atomics_contig_impl(
717727
sycl::range<1>{iter_nelems * reduction_groups * wg};
718728
auto localRange = sycl::range<1>{wg};
719729

720-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
730+
if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
731+
{
721732
using KernelName =
722733
class reduction_axis0_over_group_with_atomics_contig_krn<
723734
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
@@ -1007,10 +1018,12 @@ sycl::event reduction_over_group_temps_strided_impl(
10071018
sycl::range<1>{iter_nelems * reduction_groups * wg};
10081019
auto localRange = sycl::range<1>{wg};
10091020

1010-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1021+
if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
1022+
{
10111023
using KernelName = class reduction_over_group_temps_krn<
10121024
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
10131025
ReductionIndexerT>;
1026+
10141027
cgh.parallel_for<KernelName>(
10151028
sycl::nd_range<1>(globalRange, localRange),
10161029
ReductionOverGroupNoAtomicFunctor<
@@ -1026,6 +1039,7 @@ sycl::event reduction_over_group_temps_strided_impl(
10261039
using KernelName = class custom_reduction_over_group_temps_krn<
10271040
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
10281041
ReductionIndexerT, SlmT>;
1042+
10291043
cgh.parallel_for<KernelName>(
10301044
sycl::nd_range<1>(globalRange, localRange),
10311045
CustomReductionOverGroupNoAtomicFunctor<
@@ -1062,68 +1076,67 @@ sycl::event reduction_over_group_temps_strided_impl(
10621076
partially_reduced_tmp + reduction_groups * iter_nelems;
10631077
}
10641078

1065-
const sycl::event &first_reduction_ev =
1066-
exec_q.submit([&](sycl::handler &cgh) {
1067-
cgh.depends_on(depends);
1079+
const sycl::event &first_reduction_ev = exec_q.submit([&](sycl::handler
1080+
&cgh) {
1081+
cgh.depends_on(depends);
10681082

1069-
using InputIndexerT =
1070-
dpctl::tensor::offset_utils::StridedIndexer;
1071-
using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
1072-
using InputOutputIterIndexerT =
1073-
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
1074-
InputIndexerT, ResIndexerT>;
1075-
using ReductionIndexerT =
1076-
dpctl::tensor::offset_utils::StridedIndexer;
1083+
using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
1084+
using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
1085+
using InputOutputIterIndexerT =
1086+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
1087+
InputIndexerT, ResIndexerT>;
1088+
using ReductionIndexerT =
1089+
dpctl::tensor::offset_utils::StridedIndexer;
10771090

1078-
// Only 2*iter_nd entries describing shape and strides of
1079-
// iterated dimensions of input array from
1080-
// iter_shape_and_strides are going to be accessed by
1081-
// inp_indexer
1082-
InputIndexerT inp_indexer(iter_nd, iter_arg_offset,
1083-
iter_shape_and_strides);
1084-
ResIndexerT noop_tmp_indexer{};
1091+
// Only 2*iter_nd entries describing shape and strides of
1092+
// iterated dimensions of input array from
1093+
// iter_shape_and_strides are going to be accessed by
1094+
// inp_indexer
1095+
InputIndexerT inp_indexer(iter_nd, iter_arg_offset,
1096+
iter_shape_and_strides);
1097+
ResIndexerT noop_tmp_indexer{};
10851098

1086-
InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
1087-
noop_tmp_indexer};
1088-
ReductionIndexerT reduction_indexer{
1089-
red_nd, reduction_arg_offset, reduction_shape_stride};
1099+
InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
1100+
noop_tmp_indexer};
1101+
ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
1102+
reduction_shape_stride};
10901103

1091-
auto globalRange =
1092-
sycl::range<1>{iter_nelems * reduction_groups * wg};
1093-
auto localRange = sycl::range<1>{wg};
1104+
auto globalRange =
1105+
sycl::range<1>{iter_nelems * reduction_groups * wg};
1106+
auto localRange = sycl::range<1>{wg};
10941107

1095-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1096-
using KernelName = class reduction_over_group_temps_krn<
1108+
if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
1109+
{
1110+
using KernelName = class reduction_over_group_temps_krn<
1111+
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1112+
ReductionIndexerT>;
1113+
cgh.parallel_for<KernelName>(
1114+
sycl::nd_range<1>(globalRange, localRange),
1115+
ReductionOverGroupNoAtomicFunctor<
10971116
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1098-
ReductionIndexerT>;
1099-
cgh.parallel_for<KernelName>(
1100-
sycl::nd_range<1>(globalRange, localRange),
1101-
ReductionOverGroupNoAtomicFunctor<
1102-
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1103-
ReductionIndexerT>(
1104-
arg_tp, partially_reduced_tmp, ReductionOpT(),
1105-
identity_val, in_out_iter_indexer,
1106-
reduction_indexer, reduction_nelems, iter_nelems,
1107-
preferrered_reductions_per_wi));
1108-
}
1109-
else {
1110-
using SlmT = sycl::local_accessor<resTy, 1>;
1111-
SlmT local_memory = SlmT(localRange, cgh);
1112-
using KernelName =
1113-
class custom_reduction_over_group_temps_krn<
1114-
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1115-
ReductionIndexerT, SlmT>;
1116-
cgh.parallel_for<KernelName>(
1117-
sycl::nd_range<1>(globalRange, localRange),
1118-
CustomReductionOverGroupNoAtomicFunctor<
1119-
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1120-
ReductionIndexerT, SlmT>(
1121-
arg_tp, partially_reduced_tmp, ReductionOpT(),
1122-
identity_val, in_out_iter_indexer,
1123-
reduction_indexer, local_memory, reduction_nelems,
1124-
iter_nelems, preferrered_reductions_per_wi));
1125-
}
1126-
});
1117+
ReductionIndexerT>(
1118+
arg_tp, partially_reduced_tmp, ReductionOpT(),
1119+
identity_val, in_out_iter_indexer, reduction_indexer,
1120+
reduction_nelems, iter_nelems,
1121+
preferrered_reductions_per_wi));
1122+
}
1123+
else {
1124+
using SlmT = sycl::local_accessor<resTy, 1>;
1125+
SlmT local_memory = SlmT(localRange, cgh);
1126+
using KernelName = class custom_reduction_over_group_temps_krn<
1127+
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1128+
ReductionIndexerT, SlmT>;
1129+
cgh.parallel_for<KernelName>(
1130+
sycl::nd_range<1>(globalRange, localRange),
1131+
CustomReductionOverGroupNoAtomicFunctor<
1132+
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1133+
ReductionIndexerT, SlmT>(
1134+
arg_tp, partially_reduced_tmp, ReductionOpT(),
1135+
identity_val, in_out_iter_indexer, reduction_indexer,
1136+
local_memory, reduction_nelems, iter_nelems,
1137+
preferrered_reductions_per_wi));
1138+
}
1139+
});
11271140

11281141
size_t remaining_reduction_nelems = reduction_groups;
11291142

@@ -1165,7 +1178,8 @@ sycl::event reduction_over_group_temps_strided_impl(
11651178
auto globalRange =
11661179
sycl::range<1>{iter_nelems * reduction_groups_ * wg};
11671180
auto localRange = sycl::range<1>{wg};
1168-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1181+
if constexpr (can_use_reduce_over_group<ReductionOpT,
1182+
resTy>::value) {
11691183
using KernelName = class reduction_over_group_temps_krn<
11701184
resTy, resTy, ReductionOpT, InputOutputIterIndexerT,
11711185
ReductionIndexerT>;
@@ -1240,7 +1254,8 @@ sycl::event reduction_over_group_temps_strided_impl(
12401254
sycl::range<1>{iter_nelems * reduction_groups * wg};
12411255
auto localRange = sycl::range<1>{wg};
12421256

1243-
if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1257+
if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
1258+
{
12441259
using KernelName = class reduction_over_group_temps_krn<
12451260
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
12461261
ReductionIndexerT>;
@@ -2564,7 +2579,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
25642579
sycl::range<1>{iter_nelems * reduction_groups * wg};
25652580
auto localRange = sycl::range<1>{wg};
25662581

2567-
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2582+
if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
2583+
{
25682584
using KernelName = class search_reduction_over_group_temps_krn<
25692585
argTy, resTy, ReductionOpT, IndexOpT,
25702586
InputOutputIterIndexerT, ReductionIndexerT, true, true>;
@@ -2663,7 +2679,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
26632679
sycl::range<1>{iter_nelems * reduction_groups * wg};
26642680
auto localRange = sycl::range<1>{wg};
26652681

2666-
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2682+
if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
2683+
{
26672684
using KernelName = class search_reduction_over_group_temps_krn<
26682685
argTy, resTy, ReductionOpT, IndexOpT,
26692686
InputOutputIterIndexerT, ReductionIndexerT, true, false>;
@@ -2743,7 +2760,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
27432760
auto globalRange =
27442761
sycl::range<1>{iter_nelems * reduction_groups_ * wg};
27452762
auto localRange = sycl::range<1>{wg};
2746-
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2763+
if constexpr (can_use_reduce_over_group<ReductionOpT,
2764+
resTy>::value) {
27472765
using KernelName =
27482766
class search_reduction_over_group_temps_krn<
27492767
argTy, resTy, ReductionOpT, IndexOpT,
@@ -2826,7 +2844,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
28262844
sycl::range<1>{iter_nelems * reduction_groups * wg};
28272845
auto localRange = sycl::range<1>{wg};
28282846

2829-
if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2847+
if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
2848+
{
28302849
using KernelName = class search_reduction_over_group_temps_krn<
28312850
argTy, resTy, ReductionOpT, IndexOpT,
28322851
InputOutputIterIndexerT, ReductionIndexerT, false, true>;

0 commit comments

Comments
 (0)