Skip to content

Commit 909d894

Browse files
committed
Changes per review by @oleksandr-pavlyk
1 parent 821a888 commit 909d894

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,16 @@ template <typename argT, typename resT> struct Expm1Functor
7070
using realT = typename argT::value_type;
7171
// expm1(x + I*y) = expm1(x)*cos(y) - 2*sin(y / 2)^2 +
7272
// I*exp(x)*sin(y)
73-
auto x = std::real(in);
74-
const realT expm1X_val = std::expm1(x);
75-
const realT expX_val = std::exp(x);
73+
const realT x = std::real(in);
74+
const realT y = std::imag(in);
7675

77-
x = std::imag(in);
7876
realT cosY_val;
79-
const realT sinY_val = sycl::sincos(x, &cosY_val);
80-
const realT sinhalfY_val = std::sin(x / realT{2});
77+
const realT sinY_val = sycl::sincos(y, &cosY_val);
78+
const realT sinhalfY_val = std::sin(y / 2);
8179

8280
const realT res_re =
83-
expm1X_val * cosY_val - realT{2} * sinhalfY_val * sinhalfY_val;
84-
const realT res_im = expX_val * sinY_val;
81+
std::expm1(x) * cosY_val - 2 * sinhalfY_val * sinhalfY_val;
82+
const realT res_im = std::exp(x) * sinY_val;
8583
return resT{res_re, res_im};
8684
}
8785
else {

dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,17 @@ template <typename argT, typename resT> struct Log1pFunctor
8080
const realT y = std::imag(in);
8181

8282
// imaginary part of result
83-
const realT imagRes = std::atan2(y, x + 1);
83+
const realT res_im = std::atan2(y, x + 1);
8484

85-
if (std::abs(in) < realT(.5)) {
86-
const realT realRes = x * (2 + x) + y * y;
87-
return {std::log1p(realRes) / 2, imagRes};
85+
if (std::max(std::abs(x), std::abs(y)) < realT(.1)) {
86+
const realT v = x * (2 + x) + y * y;
87+
return {std::log1p(v) / 2, res_im};
8888
}
8989
else {
9090
// when not close to zero,
9191
// prevent overflow
92-
const realT realRes = std::hypot(x + 1, y);
93-
return {std::log(realRes), imagRes};
92+
const realT m = std::hypot(x + 1, y);
93+
return {std::log(m), res_im};
9494
}
9595
}
9696
else {

0 commit comments

Comments
 (0)