@@ -137,8 +137,42 @@ template <typename argT, typename resT> struct SqrtFunctor
137
137
}
138
138
}
139
139
140
+ int get_normal_scale_float (const float &v) const
141
+ {
142
+ constexpr int float_significant_bits = 23 ;
143
+ constexpr std::uint32_t exponent_mask = 0xff ;
144
+ constexpr int exponent_bias = 127 ;
145
+ const int scale = static_cast <int >(
146
+ (sycl::bit_cast<std::uint32_t >(v) >> float_significant_bits) &
147
+ exponent_mask);
148
+ return scale - exponent_bias;
149
+ }
150
+
151
+ int get_normal_scale_double (const double &v) const
152
+ {
153
+ constexpr int float_significant_bits = 53 ;
154
+ constexpr std::uint64_t exponent_mask = 0x7ff ;
155
+ constexpr int exponent_bias = 1023 ;
156
+ const int scale = static_cast <int >(
157
+ (sycl::bit_cast<std::uint64_t >(v) >> float_significant_bits) &
158
+ exponent_mask);
159
+ return scale - exponent_bias;
160
+ }
161
+
162
+ template <typename T> int get_normal_scale (const T &v) const
163
+ {
164
+ static_assert (std::is_same_v<T, float > || std::is_same_v<T, double >);
165
+
166
+ if constexpr (std::is_same_v<T, float >) {
167
+ return get_normal_scale_float (v);
168
+ }
169
+ else {
170
+ return get_normal_scale_double (v);
171
+ }
172
+ }
173
+
140
174
template <typename T>
141
- std::complex<T> csqrt_finite (T const &x, T const &y) const
175
+ std::complex<T> csqrt_finite_scaled (T const &x, T const &y) const
142
176
{
143
177
// csqrt(x + y*1j) =
144
178
// sqrt((cabs(x, y) + x) / 2) +
@@ -148,8 +182,8 @@ template <typename argT, typename resT> struct SqrtFunctor
148
182
constexpr realT half = realT (0x1 .0p-1f ); // 1/2
149
183
constexpr realT zero = realT (0 );
150
184
151
- const int exp_x = sycl::ilogb (x);
152
- const int exp_y = sycl::ilogb (y);
185
+ const int exp_x = get_normal_scale<realT> (x);
186
+ const int exp_y = get_normal_scale<realT> (y);
153
187
154
188
int sc = std::max<int >(exp_x, exp_y) / 2 ;
155
189
const realT xx = sycl::ldexp (x, -sc * 2 );
@@ -170,6 +204,51 @@ template <typename argT, typename resT> struct SqrtFunctor
170
204
return {sycl::ldexp (d, sc), sycl::ldexp (res_im, sc)};
171
205
}
172
206
}
207
+
208
+ template <typename T>
209
+ std::complex<T> csqrt_finite_unscaled (T const &x, T const &y) const
210
+ {
211
+ // csqrt(x + y*1j) =
212
+ // sqrt((cabs(x, y) + x) / 2) +
213
+ // 1j * copysign(sqrt((cabs(x, y) - x) / 2), y)
214
+
215
+ using realT = T;
216
+ constexpr realT half = realT (0x1 .0p-1f ); // 1/2
217
+ constexpr realT zero = realT (0 );
218
+
219
+ if (std::signbit (x)) {
220
+ const realT m = std::hypot (x, y);
221
+ const realT d = std::sqrt ((m - x) * half);
222
+ const realT res_re = (d == zero ? zero : std::abs (y) / d * half);
223
+ const realT res_im = std::copysign (d, y);
224
+ return {res_re, res_im};
225
+ }
226
+ else {
227
+ const realT m = std::hypot (x, y);
228
+ const realT d = std::sqrt ((m + x) * half);
229
+ const realT res_im =
230
+ (d == zero) ? std::copysign (zero, y) : y * half / d;
231
+ return {d, res_im};
232
+ }
233
+ }
234
+
235
+ template <typename T> T scaling_threshold () const
236
+ {
237
+ if constexpr (std::is_same_v<T, float >) {
238
+ return T (0x1 .0p+126f );
239
+ }
240
+ else {
241
+ return T (0x1 .0p+1022 );
242
+ }
243
+ }
244
+
245
+ template <typename T>
246
+ std::complex<T> csqrt_finite (T const &x, T const &y) const
247
+ {
248
+ return (std::max<T>(std::abs (x), std::abs (y)) < scaling_threshold<T>())
249
+ ? csqrt_finite_unscaled (x, y)
250
+ : csqrt_finite_scaled (x, y);
251
+ }
173
252
};
174
253
175
254
template <typename argTy,
0 commit comments