Skip to content

Commit 3c0aeed

Browse files
Modularized logic implementing logaddexp
If both arguments are -inf, the result is also -inf.
1 parent 3c87433 commit 3c0aeed

File tree

1 file changed

+20
-21
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+20
-21
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,7 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
6161

6262
resT operator()(const argT1 &in1, const argT2 &in2)
6363
{
64-
resT max = std::max<resT>(in1, in2);
65-
if (std::isnan(max)) {
66-
return std::numeric_limits<resT>::quiet_NaN();
67-
}
68-
else {
69-
if (std::isinf(max)) {
70-
return std::numeric_limits<resT>::infinity();
71-
}
72-
}
73-
resT min = std::min<resT>(in1, in2);
74-
return max + std::log1p(std::exp(min - max));
64+
return impl<resT>(in1, in2);
7565
}
7666

7767
template <int vec_sz>
@@ -83,20 +73,29 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
8373

8474
#pragma unroll
8575
for (int i = 0; i < vec_sz; ++i) {
86-
resT max = std::max<resT>(in1[i], in2[i]);
87-
if (std::isnan(max)) {
88-
res[i] = std::numeric_limits<resT>::quiet_NaN();
89-
}
90-
else if (std::isinf(max)) {
91-
res[i] = std::numeric_limits<resT>::infinity();
92-
}
93-
else {
94-
res[i] = max + std::log1p(std::exp(std::abs(diff[i])));
95-
}
76+
res[i] = impl<resT>(in1[i], in2[i]);
9677
}
9778

9879
return res;
9980
}
81+
82+
private:
83+
template <typename T> T impl(T const &in1, T const &in2)
84+
{
85+
T max = std::max<T>(in1, in2);
86+
if (std::isnan(max)) {
87+
return std::numeric_limits<T>::quiet_NaN();
88+
}
89+
else {
90+
if (std::isinf(max)) {
91+
// if both args are -inf, and hence max is -inf
92+
// the result is -inf as well
93+
return max;
94+
}
95+
}
96+
T min = std::min<T>(in1, in2);
97+
return max + std::log1p(std::exp(min - max));
98+
}
10099
};
101100

102101
template <typename argT1,

0 commit comments

Comments
 (0)