@@ -3476,7 +3476,9 @@ struct SequentialSearchReduction
3476
3476
idx_val = static_cast <outT>(m);
3477
3477
}
3478
3478
}
3479
- else if constexpr (std::is_floating_point_v<argT>) {
3479
+ else if constexpr (std::is_floating_point_v<argT> ||
3480
+ std::is_same_v<argT, sycl::half>)
3481
+ {
3480
3482
if (val < red_val || std::isnan (val)) {
3481
3483
red_val = val;
3482
3484
idx_val = static_cast <outT>(m);
@@ -3501,7 +3503,9 @@ struct SequentialSearchReduction
3501
3503
idx_val = static_cast <outT>(m);
3502
3504
}
3503
3505
}
3504
- else if constexpr (std::is_floating_point_v<argT>) {
3506
+ else if constexpr (std::is_floating_point_v<argT> ||
3507
+ std::is_same_v<argT, sycl::half>)
3508
+ {
3505
3509
if (val > red_val || std::isnan (val)) {
3506
3510
red_val = val;
3507
3511
idx_val = static_cast <outT>(m);
@@ -3789,7 +3793,9 @@ struct CustomSearchReduction
3789
3793
}
3790
3794
}
3791
3795
}
3792
- else if constexpr (std::is_floating_point_v<argT>) {
3796
+ else if constexpr (std::is_floating_point_v<argT> ||
3797
+ std::is_same_v<argT, sycl::half>)
3798
+ {
3793
3799
if (val < local_red_val || std::isnan (val)) {
3794
3800
local_red_val = val;
3795
3801
if constexpr (!First) {
@@ -3833,7 +3839,9 @@ struct CustomSearchReduction
3833
3839
}
3834
3840
}
3835
3841
}
3836
- else if constexpr (std::is_floating_point_v<argT>) {
3842
+ else if constexpr (std::is_floating_point_v<argT> ||
3843
+ std::is_same_v<argT, sycl::half>)
3844
+ {
3837
3845
if (val > local_red_val || std::isnan (val)) {
3838
3846
local_red_val = val;
3839
3847
if constexpr (!First) {
@@ -3876,7 +3884,9 @@ struct CustomSearchReduction
3876
3884
? local_idx
3877
3885
: idx_identity_;
3878
3886
}
3879
- else if constexpr (std::is_floating_point_v<argT>) {
3887
+ else if constexpr (std::is_floating_point_v<argT> ||
3888
+ std::is_same_v<argT, sycl::half>)
3889
+ {
3880
3890
// equality does not hold for NaNs, so check here
3881
3891
local_idx =
3882
3892
(red_val_over_wg == local_red_val || std::isnan (local_red_val))
0 commit comments