Skip to content

Commit 19179e6

Browse files
committed
Fixed some complex special cases for expm1
1 parent 341d4da commit 19179e6

File tree

1 file changed

+40
-0
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+40
-0
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <cmath>
2929
#include <cstddef>
3030
#include <cstdint>
31+
#include <limits>
3132
#include <type_traits>
3233

3334
#include "kernels/elementwise_functions/common.hpp"
@@ -73,6 +74,45 @@ template <typename argT, typename resT> struct Expm1Functor
7374
const realT x = std::real(in);
7475
const realT y = std::imag(in);
7576

77+
// special cases
78+
if (std::isinf(x)) {
79+
if (x > realT(0)) {
80+
// positive infinity cases
81+
if (!std::isfinite(y)) {
82+
return resT{x, std::numeric_limits<realT>::quiet_NaN()};
83+
}
84+
else if (y == realT(0)) {
85+
return in;
86+
}
87+
else {
88+
return (std::numeric_limits<realT>::infinity() *
89+
resT{std::cos(y), std::sin(y)} -
90+
realT(1));
91+
}
92+
}
93+
else {
94+
// negative infinity cases
95+
if (!std::isfinite(y)) {
96+
return resT{-1, 0};
97+
}
98+
else {
99+
return (realT(0) * resT{std::cos(y), std::sin(y)} -
100+
realT(1));
101+
}
102+
}
103+
}
104+
105+
if (std::isnan(x)) {
106+
if (y == realT(0)) {
107+
return in;
108+
}
109+
else {
110+
return resT{std::numeric_limits<realT>::quiet_NaN(),
111+
std::numeric_limits<realT>::quiet_NaN()};
112+
}
113+
}
114+
115+
// x, y finite numbers
76116
realT cosY_val;
77117
const realT sinY_val = sycl::sincos(y, &cosY_val);
78118
const realT sinhalfY_val = std::sin(y / 2);

0 commit comments

Comments
 (0)