@@ -61,17 +61,7 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
61
61
62
62
resT operator ()(const argT1 &in1, const argT2 &in2)
63
63
{
64
- resT max = std::max<resT>(in1, in2);
65
- if (std::isnan (max)) {
66
- return std::numeric_limits<resT>::quiet_NaN ();
67
- }
68
- else {
69
- if (std::isinf (max)) {
70
- return std::numeric_limits<resT>::infinity ();
71
- }
72
- }
73
- resT min = std::min<resT>(in1, in2);
74
- return max + std::log1p (std::exp (min - max));
64
+ return impl<resT>(in1, in2);
75
65
}
76
66
77
67
template <int vec_sz>
@@ -83,20 +73,29 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
83
73
84
74
#pragma unroll
85
75
for (int i = 0 ; i < vec_sz; ++i) {
86
- resT max = std::max<resT>(in1[i], in2[i]);
87
- if (std::isnan (max)) {
88
- res[i] = std::numeric_limits<resT>::quiet_NaN ();
89
- }
90
- else if (std::isinf (max)) {
91
- res[i] = std::numeric_limits<resT>::infinity ();
92
- }
93
- else {
94
- res[i] = max + std::log1p (std::exp (std::abs (diff[i])));
95
- }
76
+ res[i] = impl<resT>(in1[i], in2[i]);
96
77
}
97
78
98
79
return res;
99
80
}
81
+
82
+ private:
83
+ template <typename T> T impl (T const &in1, T const &in2)
84
+ {
85
+ T max = std::max<T>(in1, in2);
86
+ if (std::isnan (max)) {
87
+ return std::numeric_limits<T>::quiet_NaN ();
88
+ }
89
+ else {
90
+ if (std::isinf (max)) {
91
+ // if both args are -inf, and hence max is -inf
92
+ // the result is -inf as well
93
+ return max;
94
+ }
95
+ }
96
+ T min = std::min<T>(in1, in2);
97
+ return max + std::log1p (std::exp (min - max));
98
+ }
100
99
};
101
100
102
101
template <typename argT1,
0 commit comments