Skip to content

Commit acd93ae

Browse files
committed
Search reductions use correct branch for float16
constexpr branch logic accounted for floating point types but not sycl::half, which meant NaNs were not propagating for float16 data
1 parent e5d20db commit acd93ae

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3475,7 +3475,9 @@ struct SequentialSearchReduction
34753475
idx_val = static_cast<outT>(m);
34763476
}
34773477
}
3478-
else if constexpr (std::is_floating_point_v<argT>) {
3478+
else if constexpr (std::is_floating_point_v<argT> ||
3479+
std::is_same_v<argT, sycl::half>)
3480+
{
34793481
if (val < red_val || std::isnan(val)) {
34803482
red_val = val;
34813483
idx_val = static_cast<outT>(m);
@@ -3500,7 +3502,9 @@ struct SequentialSearchReduction
35003502
idx_val = static_cast<outT>(m);
35013503
}
35023504
}
3503-
else if constexpr (std::is_floating_point_v<argT>) {
3505+
else if constexpr (std::is_floating_point_v<argT> ||
3506+
std::is_same_v<argT, sycl::half>)
3507+
{
35043508
if (val > red_val || std::isnan(val)) {
35053509
red_val = val;
35063510
idx_val = static_cast<outT>(m);
@@ -3788,7 +3792,9 @@ struct CustomSearchReduction
37883792
}
37893793
}
37903794
}
3791-
else if constexpr (std::is_floating_point_v<argT>) {
3795+
else if constexpr (std::is_floating_point_v<argT> ||
3796+
std::is_same_v<argT, sycl::half>)
3797+
{
37923798
if (val < local_red_val || std::isnan(val)) {
37933799
local_red_val = val;
37943800
if constexpr (!First) {
@@ -3832,7 +3838,9 @@ struct CustomSearchReduction
38323838
}
38333839
}
38343840
}
3835-
else if constexpr (std::is_floating_point_v<argT>) {
3841+
else if constexpr (std::is_floating_point_v<argT> ||
3842+
std::is_same_v<argT, sycl::half>)
3843+
{
38363844
if (val > local_red_val || std::isnan(val)) {
38373845
local_red_val = val;
38383846
if constexpr (!First) {
@@ -3875,7 +3883,9 @@ struct CustomSearchReduction
38753883
? local_idx
38763884
: idx_identity_;
38773885
}
3878-
else if constexpr (std::is_floating_point_v<argT>) {
3886+
else if constexpr (std::is_floating_point_v<argT> ||
3887+
std::is_same_v<argT, sycl::half>)
3888+
{
38793889
// equality does not hold for NaNs, so check here
38803890
local_idx =
38813891
(red_val_over_wg == local_red_val || std::isnan(local_red_val))

0 commit comments

Comments
 (0)