Skip to content

Commit 4ae55f8

Browse files
Fix for test failure on AMD CPU.
vec operator should also apply isnan for sycl::half
1 parent 02d46b4 commit 4ae55f8

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,13 @@ template <typename argT1, typename argT2, typename resT> struct MaximumFunctor
7272
}
7373
else if constexpr (std::is_floating_point_v<argT1> ||
7474
std::is_same_v<argT1, sycl::half>)
75-
return (std::isnan(in1) || in1 > in2) ? in1 : in2;
76-
else
75+
{
76+
const bool choose_first = (std::isnan(in1) || (in1 > in2));
77+
return (choose_first) ? in1 : in2;
78+
}
79+
else {
7780
return (in1 > in2) ? in1 : in2;
81+
}
7882
}
7983

8084
template <int vec_sz>
@@ -85,11 +89,17 @@ template <typename argT1, typename argT2, typename resT> struct MaximumFunctor
8589
sycl::vec<resT, vec_sz> res;
8690
#pragma unroll
8791
for (int i = 0; i < vec_sz; ++i) {
88-
if constexpr (std::is_floating_point_v<argT1>)
89-
res[i] =
90-
(sycl::isnan(in1[i]) || in1[i] > in2[i]) ? in1[i] : in2[i];
91-
else
92-
res[i] = (in1[i] > in2[i]) ? in1[i] : in2[i];
92+
const auto &v1 = in1[i];
93+
const auto &v2 = in2[i];
94+
if constexpr (std::is_floating_point_v<argT1> ||
95+
std::is_same_v<argT1, sycl::half>)
96+
{
97+
const bool choose_first = (std::isnan(v1) || (v1 > v2));
98+
res[i] = (choose_first) ? v1 : v2;
99+
}
100+
else {
101+
res[i] = (v1 > v2) ? v1 : v2;
102+
}
93103
}
94104
return res;
95105
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,13 @@ template <typename argT1, typename argT2, typename resT> struct MinimumFunctor
7272
}
7373
else if constexpr (std::is_floating_point_v<argT1> ||
7474
std::is_same_v<argT1, sycl::half>)
75-
return (std::isnan(in1) || in1 < in2) ? in1 : in2;
76-
else
75+
{
76+
const bool choose_first = sycl::isnan(in1) || (in1 < in2);
77+
return (choose_first) ? in1 : in2;
78+
}
79+
else {
7780
return (in1 < in2) ? in1 : in2;
81+
}
7882
}
7983

8084
template <int vec_sz>
@@ -85,11 +89,17 @@ template <typename argT1, typename argT2, typename resT> struct MinimumFunctor
8589
sycl::vec<resT, vec_sz> res;
8690
#pragma unroll
8791
for (int i = 0; i < vec_sz; ++i) {
88-
if constexpr (std::is_floating_point_v<argT1>)
89-
res[i] =
90-
(sycl::isnan(in1[i]) || in1[i] < in2[i]) ? in1[i] : in2[i];
91-
else
92-
res[i] = (in1[i] < in2[i]) ? in1[i] : in2[i];
92+
const auto &v1 = in1[i];
93+
const auto &v2 = in2[i];
94+
if constexpr (std::is_floating_point_v<argT1> ||
95+
std::is_same_v<argT1, sycl::half>)
96+
{
97+
const bool choose_first = sycl::isnan(v1) || (v1 < v2);
98+
res[i] = (choose_first) ? v1 : v2;
99+
}
100+
else {
101+
res[i] = (v1 < v2) ? v1 : v2;
102+
}
93103
}
94104
return res;
95105
}

0 commit comments

Comments
 (0)