Skip to content

Commit 119d43d

Browse files
committed
Search reductions use correct branch for float16
constexpr branch logic accounted for floating point types but not sycl::half, which meant NaNs were not propagating for float16 data
1 parent 5709f99 commit 119d43d

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3476,7 +3476,9 @@ struct SequentialSearchReduction
34763476
idx_val = static_cast<outT>(m);
34773477
}
34783478
}
3479-
else if constexpr (std::is_floating_point_v<argT>) {
3479+
else if constexpr (std::is_floating_point_v<argT> ||
3480+
std::is_same_v<argT, sycl::half>)
3481+
{
34803482
if (val < red_val || std::isnan(val)) {
34813483
red_val = val;
34823484
idx_val = static_cast<outT>(m);
@@ -3501,7 +3503,9 @@ struct SequentialSearchReduction
35013503
idx_val = static_cast<outT>(m);
35023504
}
35033505
}
3504-
else if constexpr (std::is_floating_point_v<argT>) {
3506+
else if constexpr (std::is_floating_point_v<argT> ||
3507+
std::is_same_v<argT, sycl::half>)
3508+
{
35053509
if (val > red_val || std::isnan(val)) {
35063510
red_val = val;
35073511
idx_val = static_cast<outT>(m);
@@ -3789,7 +3793,9 @@ struct CustomSearchReduction
37893793
}
37903794
}
37913795
}
3792-
else if constexpr (std::is_floating_point_v<argT>) {
3796+
else if constexpr (std::is_floating_point_v<argT> ||
3797+
std::is_same_v<argT, sycl::half>)
3798+
{
37933799
if (val < local_red_val || std::isnan(val)) {
37943800
local_red_val = val;
37953801
if constexpr (!First) {
@@ -3833,7 +3839,9 @@ struct CustomSearchReduction
38333839
}
38343840
}
38353841
}
3836-
else if constexpr (std::is_floating_point_v<argT>) {
3842+
else if constexpr (std::is_floating_point_v<argT> ||
3843+
std::is_same_v<argT, sycl::half>)
3844+
{
38373845
if (val > local_red_val || std::isnan(val)) {
38383846
local_red_val = val;
38393847
if constexpr (!First) {
@@ -3876,7 +3884,9 @@ struct CustomSearchReduction
38763884
? local_idx
38773885
: idx_identity_;
38783886
}
3879-
else if constexpr (std::is_floating_point_v<argT>) {
3887+
else if constexpr (std::is_floating_point_v<argT> ||
3888+
std::is_same_v<argT, sycl::half>)
3889+
{
38803890
// equality does not hold for NaNs, so check here
38813891
local_idx =
38823892
(red_val_over_wg == local_red_val || std::isnan(local_red_val))

0 commit comments

Comments
 (0)