Skip to content

Commit bb52bb1

Browse files
Avoid using sycl::ilogb, but use own implementation
ilogb would have to pay attention to correctly computing scale of denormal floats, while simpler code suffices. Also use unscaled version in most cases, and scaled version only for very large inputs.
1 parent c4312cb commit bb52bb1

File tree

1 file changed

+82
-3
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+82
-3
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,42 @@ template <typename argT, typename resT> struct SqrtFunctor
137137
}
138138
}
139139

140+
int get_normal_scale_float(const float &v) const
141+
{
142+
constexpr int float_significant_bits = 23;
143+
constexpr std::uint32_t exponent_mask = 0xff;
144+
constexpr int exponent_bias = 127;
145+
const int scale = static_cast<int>(
146+
(sycl::bit_cast<std::uint32_t>(v) >> float_significant_bits) &
147+
exponent_mask);
148+
return scale - exponent_bias;
149+
}
150+
151+
int get_normal_scale_double(const double &v) const
152+
{
153+
constexpr int float_significant_bits = 53;
154+
constexpr std::uint64_t exponent_mask = 0x7ff;
155+
constexpr int exponent_bias = 1023;
156+
const int scale = static_cast<int>(
157+
(sycl::bit_cast<std::uint64_t>(v) >> float_significant_bits) &
158+
exponent_mask);
159+
return scale - exponent_bias;
160+
}
161+
162+
template <typename T> int get_normal_scale(const T &v) const
163+
{
164+
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>);
165+
166+
if constexpr (std::is_same_v<T, float>) {
167+
return get_normal_scale_float(v);
168+
}
169+
else {
170+
return get_normal_scale_double(v);
171+
}
172+
}
173+
140174
template <typename T>
141-
std::complex<T> csqrt_finite(T const &x, T const &y) const
175+
std::complex<T> csqrt_finite_scaled(T const &x, T const &y) const
142176
{
143177
// csqrt(x + y*1j) =
144178
// sqrt((cabs(x, y) + x) / 2) +
@@ -148,8 +182,8 @@ template <typename argT, typename resT> struct SqrtFunctor
148182
constexpr realT half = realT(0x1.0p-1f); // 1/2
149183
constexpr realT zero = realT(0);
150184

151-
const int exp_x = sycl::ilogb(x);
152-
const int exp_y = sycl::ilogb(y);
185+
const int exp_x = get_normal_scale<realT>(x);
186+
const int exp_y = get_normal_scale<realT>(y);
153187

154188
int sc = std::max<int>(exp_x, exp_y) / 2;
155189
const realT xx = sycl::ldexp(x, -sc * 2);
@@ -170,6 +204,51 @@ template <typename argT, typename resT> struct SqrtFunctor
170204
return {sycl::ldexp(d, sc), sycl::ldexp(res_im, sc)};
171205
}
172206
}
207+
208+
template <typename T>
209+
std::complex<T> csqrt_finite_unscaled(T const &x, T const &y) const
210+
{
211+
// csqrt(x + y*1j) =
212+
// sqrt((cabs(x, y) + x) / 2) +
213+
// 1j * copysign(sqrt((cabs(x, y) - x) / 2), y)
214+
215+
using realT = T;
216+
constexpr realT half = realT(0x1.0p-1f); // 1/2
217+
constexpr realT zero = realT(0);
218+
219+
if (std::signbit(x)) {
220+
const realT m = std::hypot(x, y);
221+
const realT d = std::sqrt((m - x) * half);
222+
const realT res_re = (d == zero ? zero : std::abs(y) / d * half);
223+
const realT res_im = std::copysign(d, y);
224+
return {res_re, res_im};
225+
}
226+
else {
227+
const realT m = std::hypot(x, y);
228+
const realT d = std::sqrt((m + x) * half);
229+
const realT res_im =
230+
(d == zero) ? std::copysign(zero, y) : y * half / d;
231+
return {d, res_im};
232+
}
233+
}
234+
235+
template <typename T> T scaling_threshold() const
236+
{
237+
if constexpr (std::is_same_v<T, float>) {
238+
return T(0x1.0p+126f);
239+
}
240+
else {
241+
return T(0x1.0p+1022);
242+
}
243+
}
244+
245+
template <typename T>
246+
std::complex<T> csqrt_finite(T const &x, T const &y) const
247+
{
248+
return (std::max<T>(std::abs(x), std::abs(y)) < scaling_threshold<T>())
249+
? csqrt_finite_unscaled(x, y)
250+
: csqrt_finite_scaled(x, y);
251+
}
173252
};
174253

175254
template <typename argTy,

0 commit comments

Comments
 (0)