1
1
// ===-- Implementation of sqrtf128 function -------------------------------===//
2
2
//
3
- // Copyright (c) 2024 Alexei Sibidanov <sibid@uvic.ca>
4
- //
5
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6
4
// See https://llvm.org/LICENSE.txt for license information.
7
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
17
15
#include " src/__support/macros/optimization.h"
18
16
#include " src/__support/uint128.h"
19
17
18
+ // Compute sqrtf128 with correct rounding for all rounding modes using integer
19
+ // arithmetic by Alexei Sibidanov (sibid@uvic.ca):
20
+ // Let the input be expressed as x = 2^e * m_x,
21
+ // - Step 1: Range reduction
22
+ // Let x_reduced = 2^(e % 2) * m_x,
23
+ // Then sqrt(x) = 2^(e / 2) * sqrt(x_reduced), with
24
+ // 1 <= x_reduced < 4.
25
+ // - Step 2: Polynomial approximation
26
+ // Approximate 1/sqrt(x_reduced) using polynomial approximation with the
27
+ // result errors bounded by:
28
+ // |r0 - 1/sqrt(x_reduced)| < 2^-32.
29
+ // The computations are done in uint64_t.
30
+ // - Step 3: First Newton iteration
31
+ // Let the scaled error be defined by:
32
+ // h0 = r0^2 * x_reduced - 1.
33
+ // Then we compute the first Newton iteration:
34
+ // r1 = r0 - r0 * h0 / 2.
35
+ // The result is then bounded by:
36
+ // |r1 - 1 / sqrt(x_reduced)| < 2^-62.
37
+ // - Step 4: Second Newton iteration
38
+ // We calculate the scaled error from Step 3:
39
+ // h1 = r1^2 * x_reduced - 1.
40
+ // Then the second Newton iteration is computed by:
41
+ // r2 = x_reduced * (r1 - r1 * h1 / 2)
42
+ // ~ x_reduced * (1/sqrt(x_reduced)) = sqrt(x_reduced)
43
+ // - Step 5: Perform rounding test and correction if needed.
44
+ // Rounding correction is done by computing the exact rounding errors:
45
+ // x_reduced - r2^2.
46
+
20
47
namespace LIBC_NAMESPACE_DECL {
21
48
22
49
using FPBits = fputil::FPBits<float128>;
@@ -35,11 +62,11 @@ inline constexpr uint64_t prod_hi<uint64_t>(uint64_t x, uint64_t y) {
35
62
36
63
// Get high part of unsigned 128x64 bit multiplication.
37
64
template <>
38
- inline constexpr UInt128 prod_hi<UInt128, uint64_t >(UInt128 y , uint64_t x ) {
39
- uint64_t y_lo = static_cast <uint64_t >(y );
40
- uint64_t y_hi = static_cast <uint64_t >(y >> 64 );
41
- UInt128 xyl = static_cast <UInt128>(x ) * static_cast <UInt128>(y_lo );
42
- UInt128 xyh = static_cast <UInt128>(x ) * static_cast <UInt128>(y_hi );
65
+ inline constexpr UInt128 prod_hi<UInt128, uint64_t >(UInt128 x , uint64_t y ) {
66
+ uint64_t x_lo = static_cast <uint64_t >(x );
67
+ uint64_t x_hi = static_cast <uint64_t >(x >> 64 );
68
+ UInt128 xyl = static_cast <UInt128>(x_lo ) * static_cast <UInt128>(y );
69
+ UInt128 xyh = static_cast <UInt128>(x_hi ) * static_cast <UInt128>(y );
43
70
return xyh + (xyl >> 64 );
44
71
}
45
72
@@ -178,11 +205,11 @@ LIBC_INLINE uint64_t rsqrt_approx(uint64_t m) {
178
205
// r1 = r0 - r0 * h / 2
179
206
// which has error bounded by:
180
207
// |r1 - 1/sqrt(x)| < h^2 / 2.
181
- uint64_t r2 = prod_hi< uint64_t > (r, r);
208
+ uint64_t r2 = prod_hi (r, r);
182
209
// h = r0^2*x - 1.
183
- int64_t h = static_cast <int64_t >(prod_hi< uint64_t > (m, r2) + r2);
210
+ int64_t h = static_cast <int64_t >(prod_hi (m, r2) + r2);
184
211
// hr = r * h / 2
185
- int64_t hr = prod_hi< int64_t > (h, static_cast <int64_t >(r >> 1 ));
212
+ int64_t hr = prod_hi (h, static_cast <int64_t >(r >> 1 ));
186
213
r -= hr;
187
214
// Adjust in the unlucky case x~1;
188
215
if (LIBC_UNLIKELY (!r))
@@ -224,8 +251,10 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
224
251
fputil::raise_except_if_required (FE_INVALID);
225
252
return xbits.quiet_nan ().get_val ();
226
253
}
227
- // x is subnormal or x=+0
228
- if (x == 0 )
254
+ // Now x is subnormal or x = +0.
255
+
256
+ // x is +0.
257
+ if (x_u == 0 )
229
258
return x;
230
259
231
260
// Normalize subnormal inputs.
@@ -253,7 +282,7 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
253
282
0xb504f333f9de6484 /* 2^64/sqrt(2) */ };
254
283
255
284
// Approximate 1/sqrt(1 + x_frac)
256
- // Error: |r_1 - 1/sqrt(x)| < 2^-63 .
285
+ // Error: |r_1 - 1/sqrt(x)| < 2^-62 .
257
286
uint64_t r1 = rsqrt_approx (static_cast <uint64_t >(x_frac >> 64 ));
258
287
// Adjust for the even/odd exponent.
259
288
uint64_t r2 = prod_hi (r1, RSQRT_2[i]);
@@ -279,8 +308,9 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
279
308
uint32_t nrst = rm == FE_TONEAREST;
280
309
// The result lies within (-2,5) of true square root so we now
281
310
// test that we can correctly round the result taking into account
282
- // the rounding mode
283
- // check the lowest 14 bits.
311
+ // the rounding mode.
312
+ // Check the lowest 14 bits (by clearing and sign-extending the top
313
+ // 32 - 14 = 18 bits).
284
314
int dd = (static_cast <int >(v) << 18 ) >> 18 ;
285
315
286
316
if (LIBC_UNLIKELY (dd < 4 && dd >= -8 )) { // can round correctly?
@@ -289,17 +319,16 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
289
319
// compare with the initial argument.
290
320
UInt128 m = v >> 15 ;
291
321
UInt128 m2 = m * m;
292
- Int128 t0, t1;
293
322
// The difference of the squared result and the argument
294
- t0 = static_cast <Int128>(m2 - (x_reduced << 98 ));
323
+ Int128 t0 = static_cast <Int128>(m2 - (x_reduced << 98 ));
295
324
if (t0 == 0 ) {
296
325
// the square root is exact
297
326
v = m << 15 ;
298
327
} else {
299
328
// Add +-1 ulp to m depending on the sign of the difference. Here
300
329
// we do not need to square again since (m+1)^2 = m^2 + 2*m +
301
330
// 1 so just need to add shifted m and 1.
302
- t1 = t0;
331
+ Int128 t1 = t0;
303
332
Int128 sgn = t0 >> 127 ; // sign of the difference
304
333
t1 -= (m << 1 ) ^ sgn;
305
334
t1 += 1 + sgn;
@@ -332,20 +361,20 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
332
361
rnd = frac >> 14 ; // round to nearest tie to even
333
362
} else if (rm == FE_UPWARD) {
334
363
rnd = !!frac; // round up
335
- } else if (rm == FE_DOWNWARD) {
336
- rnd = 0 ; // round down
337
364
} else {
338
- rnd = 0 ; // round to zero
365
+ rnd = 0 ; // round down or round to zero
339
366
}
340
367
341
368
v >>= 15 ; // position mantissa
342
369
v += rnd; // round
343
370
344
- // // Set inexact flag only if square root is inexact
345
- // // TODO: We will have to raise FE_INEXACT most of the time, but this
346
- // // operation is very costly, especially in x86-64, since technically, it
347
- // // needs to synchronize both SSE and x87 flags. Need to investigate
348
- // // further to see how we can make this performant.
371
+ // Set inexact flag only if square root is inexact
372
+ // TODO: We will have to raise FE_INEXACT most of the time, but this
373
+ // operation is very costly, especially in x86-64, since technically, it
374
+ // needs to synchronize both SSE and x87 flags. Need to investigate
375
+ // further to see how we can make this performant.
376
+ // https://github.com/llvm/llvm-project/issues/126753
377
+
349
378
// if(frac) fputil::raise_except_if_required(FE_INEXACT);
350
379
351
380
v += static_cast <UInt128>(e2 ) << FPBits::FRACTION_LEN; // place exponent
0 commit comments