Skip to content

Commit e0ce9e4

Browse files
committed
Fixed NaN propagation for max and min reductions on CPU devices
- Drops use of fetch_max/fetch_min for floats, since they do not handle NaNs correctly
1 parent 419c117 commit e0ce9e4

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 9 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -195,10 +195,12 @@ struct ReductionOverGroupWithAtomicFunctor
195195
if constexpr (su_ns::IsPlus<outT, ReductionOp>::value) {
196196
res_ref += red_val_over_wg;
197197
}
198-
else if constexpr (su_ns::IsMaximum<outT, ReductionOp>::value) {
198+
else if constexpr (std::is_same_v<ReductionOp, sycl::maximum<outT>>)
199+
{
199200
res_ref.fetch_max(red_val_over_wg);
200201
}
201-
else if constexpr (su_ns::IsMinimum<outT, ReductionOp>::value) {
202+
else if constexpr (std::is_same_v<ReductionOp, sycl::minimum<outT>>)
203+
{
202204
res_ref.fetch_min(red_val_over_wg);
203205
}
204206
else {
@@ -299,22 +301,11 @@ struct CustomReductionOverGroupWithAtomicFunctor
299301
sycl::memory_scope::device,
300302
sycl::access::address_space::global_space>
301303
res_ref(out_[out_iter_offset]);
302-
if constexpr (su_ns::IsPlus<outT, ReductionOp>::value) {
303-
res_ref += red_val_over_wg;
304-
}
305-
else if constexpr (su_ns::IsMaximum<outT, ReductionOp>::value) {
306-
res_ref.fetch_max(red_val_over_wg);
307-
}
308-
else if constexpr (su_ns::IsMinimum<outT, ReductionOp>::value) {
309-
res_ref.fetch_min(red_val_over_wg);
310-
}
311-
else {
312-
outT read_val = res_ref.load();
313-
outT new_val{};
314-
do {
315-
new_val = reduction_op_(read_val, red_val_over_wg);
316-
} while (!res_ref.compare_exchange_strong(read_val, new_val));
317-
}
304+
outT read_val = res_ref.load();
305+
outT new_val{};
306+
do {
307+
new_val = reduction_op_(read_val, red_val_over_wg);
308+
} while (!res_ref.compare_exchange_strong(read_val, new_val));
318309
}
319310
}
320311
};

0 commit comments

Comments
 (0)