@@ -3475,7 +3475,9 @@ struct SequentialSearchReduction
3475
3475
idx_val = static_cast <outT>(m);
3476
3476
}
3477
3477
}
3478
- else if constexpr (std::is_floating_point_v<argT>) {
3478
+ else if constexpr (std::is_floating_point_v<argT> ||
3479
+ std::is_same_v<argT, sycl::half>)
3480
+ {
3479
3481
if (val < red_val || std::isnan (val)) {
3480
3482
red_val = val;
3481
3483
idx_val = static_cast <outT>(m);
@@ -3500,7 +3502,9 @@ struct SequentialSearchReduction
3500
3502
idx_val = static_cast <outT>(m);
3501
3503
}
3502
3504
}
3503
- else if constexpr (std::is_floating_point_v<argT>) {
3505
+ else if constexpr (std::is_floating_point_v<argT> ||
3506
+ std::is_same_v<argT, sycl::half>)
3507
+ {
3504
3508
if (val > red_val || std::isnan (val)) {
3505
3509
red_val = val;
3506
3510
idx_val = static_cast <outT>(m);
@@ -3788,7 +3792,9 @@ struct CustomSearchReduction
3788
3792
}
3789
3793
}
3790
3794
}
3791
- else if constexpr (std::is_floating_point_v<argT>) {
3795
+ else if constexpr (std::is_floating_point_v<argT> ||
3796
+ std::is_same_v<argT, sycl::half>)
3797
+ {
3792
3798
if (val < local_red_val || std::isnan (val)) {
3793
3799
local_red_val = val;
3794
3800
if constexpr (!First) {
@@ -3832,7 +3838,9 @@ struct CustomSearchReduction
3832
3838
}
3833
3839
}
3834
3840
}
3835
- else if constexpr (std::is_floating_point_v<argT>) {
3841
+ else if constexpr (std::is_floating_point_v<argT> ||
3842
+ std::is_same_v<argT, sycl::half>)
3843
+ {
3836
3844
if (val > local_red_val || std::isnan (val)) {
3837
3845
local_red_val = val;
3838
3846
if constexpr (!First) {
@@ -3875,7 +3883,9 @@ struct CustomSearchReduction
3875
3883
? local_idx
3876
3884
: idx_identity_;
3877
3885
}
3878
- else if constexpr (std::is_floating_point_v<argT>) {
3886
+ else if constexpr (std::is_floating_point_v<argT> ||
3887
+ std::is_same_v<argT, sycl::half>)
3888
+ {
3879
3889
// equality does not hold for NaNs, so check here
3880
3890
local_idx =
3881
3891
(red_val_over_wg == local_red_val || std::isnan (local_red_val))
0 commit comments