Skip to content

Commit 3c87433

Browse files
committed
Broke up 'or' conditional in logaddexp logic for inf and NaN
- 'or' conditions can sometimes cause wrong results when using the OS compiler
1 parent 1b5419f commit 3c87433

File tree

1 file changed

+12
-4
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+12
-4
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,13 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
6262
resT operator()(const argT1 &in1, const argT2 &in2)
6363
{
6464
resT max = std::max<resT>(in1, in2);
65-
if (std::isnan(max) || std::isinf(max)) {
66-
return max;
65+
if (std::isnan(max)) {
66+
return std::numeric_limits<resT>::quiet_NaN();
67+
}
68+
else {
69+
if (std::isinf(max)) {
70+
return std::numeric_limits<resT>::infinity();
71+
}
6772
}
6873
resT min = std::min<resT>(in1, in2);
6974
return max + std::log1p(std::exp(min - max));
@@ -79,8 +84,11 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
7984
#pragma unroll
8085
for (int i = 0; i < vec_sz; ++i) {
8186
resT max = std::max<resT>(in1[i], in2[i]);
82-
if (std::isnan(max) || std::isinf(max)) {
83-
res[i] = max;
87+
if (std::isnan(max)) {
88+
res[i] = std::numeric_limits<resT>::quiet_NaN();
89+
}
90+
else if (std::isinf(max)) {
91+
res[i] = std::numeric_limits<resT>::infinity();
8492
}
8593
else {
8694
res[i] = max + std::log1p(std::exp(std::abs(diff[i])));

0 commit comments

Comments
 (0)