Skip to content

Commit 1b5419f

Browse files
committed
logaddexp now handles both NaNs and infinities correctly per array API
1 parent ff1081a commit 1b5419f

File tree

1 file changed

+6
-6
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+6
-6
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
6161

6262
resT operator()(const argT1 &in1, const argT2 &in2)
6363
{
64-
if (std::isnan(in1) || std::isnan(in2)) {
65-
return std::numeric_limits<resT>::quiet_NaN();
66-
}
6764
resT max = std::max<resT>(in1, in2);
65+
if (std::isnan(max) || std::isinf(max)) {
66+
return max;
67+
}
6868
resT min = std::min<resT>(in1, in2);
6969
return max + std::log1p(std::exp(min - max));
7070
}
@@ -78,11 +78,11 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
7878

7979
#pragma unroll
8080
for (int i = 0; i < vec_sz; ++i) {
81-
if (std::isnan(in1[i]) || std::isnan(in2[i])) {
82-
res[i] = std::numeric_limits<resT>::quiet_NaN();
81+
resT max = std::max<resT>(in1[i], in2[i]);
82+
if (std::isnan(max) || std::isinf(max)) {
83+
res[i] = max;
8384
}
8485
else {
85-
resT max = std::max<resT>(in1[i], in2[i]);
8686
res[i] = max + std::log1p(std::exp(std::abs(diff[i])));
8787
}
8888
}

0 commit comments

Comments
 (0)