Skip to content

Commit 998cc27

Browse files
committed
Fixed NaN propagation in max and min reductions for CPU devices
- Drops the use of fetch_max/fetch_min for floating-point types, because these atomic operations do not handle NaNs correctly
1 parent bb48c33 commit 998cc27

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 9 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -196,10 +196,12 @@ struct ReductionOverGroupWithAtomicFunctor
196196
if constexpr (su_ns::IsPlus<outT, ReductionOp>::value) {
197197
res_ref += red_val_over_wg;
198198
}
199-
else if constexpr (su_ns::IsMaximum<outT, ReductionOp>::value) {
199+
else if constexpr (std::is_same_v<ReductionOp, sycl::maximum<outT>>)
200+
{
200201
res_ref.fetch_max(red_val_over_wg);
201202
}
202-
else if constexpr (su_ns::IsMinimum<outT, ReductionOp>::value) {
203+
else if constexpr (std::is_same_v<ReductionOp, sycl::minimum<outT>>)
204+
{
203205
res_ref.fetch_min(red_val_over_wg);
204206
}
205207
else {
@@ -300,22 +302,11 @@ struct CustomReductionOverGroupWithAtomicFunctor
300302
sycl::memory_scope::device,
301303
sycl::access::address_space::global_space>
302304
res_ref(out_[out_iter_offset]);
303-
if constexpr (su_ns::IsPlus<outT, ReductionOp>::value) {
304-
res_ref += red_val_over_wg;
305-
}
306-
else if constexpr (su_ns::IsMaximum<outT, ReductionOp>::value) {
307-
res_ref.fetch_max(red_val_over_wg);
308-
}
309-
else if constexpr (su_ns::IsMinimum<outT, ReductionOp>::value) {
310-
res_ref.fetch_min(red_val_over_wg);
311-
}
312-
else {
313-
outT read_val = res_ref.load();
314-
outT new_val{};
315-
do {
316-
new_val = reduction_op_(read_val, red_val_over_wg);
317-
} while (!res_ref.compare_exchange_strong(read_val, new_val));
318-
}
305+
outT read_val = res_ref.load();
306+
outT new_val{};
307+
do {
308+
new_val = reduction_op_(read_val, red_val_over_wg);
309+
} while (!res_ref.compare_exchange_strong(read_val, new_val));
319310
}
320311
}
321312
};

0 commit comments

Comments
 (0)