Skip to content

Commit b4d2045

Browse files
committed
Address comments.
1 parent c1e3334 commit b4d2045

File tree

1 file changed

+55
-26
lines changed

1 file changed

+55
-26
lines changed

libc/src/math/generic/sqrtf128.cpp

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
//===-- Implementation of sqrtf128 function -------------------------------===//
22
//
3-
// Copyright (c) 2024 Alexei Sibidanov <[email protected]>
4-
//
53
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
64
// See https://llvm.org/LICENSE.txt for license information.
75
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -17,6 +15,35 @@
1715
#include "src/__support/macros/optimization.h"
1816
#include "src/__support/uint128.h"
1917

18+
// Compute sqrtf128 with correct rounding for all rounding modes using integer
19+
// arithmetic by Alexei Sibidanov ([email protected]):
20+
// Let the input be expressed as x = 2^e * m_x,
21+
// - Step 1: Range reduction
22+
// Let x_reduced = 2^(e % 2) * m_x,
23+
// Then sqrt(x) = 2^(e / 2) * sqrt(x_reduced), with
24+
// 1 <= x_reduced < 4.
25+
// - Step 2: Polynomial approximation
26+
// Approximate 1/sqrt(x_reduced) using polynomial approximation with the
27+
// result errors bounded by:
28+
// |r0 - 1/sqrt(x_reduced)| < 2^-32.
29+
// The computations are done in uint64_t.
30+
// - Step 3: First Newton iteration
31+
// Let the scaled error be defined by:
32+
// h0 = r0^2 * x_reduced - 1.
33+
// Then we compute the first Newton iteration:
34+
// r1 = r0 - r0 * h0 / 2.
35+
// The result is then bounded by:
36+
// |r1 - 1 / sqrt(x_reduced)| < 2^-62.
37+
// - Step 4: Second Newton iteration
38+
// We calculate the scaled error from Step 3:
39+
// h1 = r1^2 * x_reduced - 1.
40+
// Then the second Newton iteration is computed by:
41+
// r2 = x_reduced * (r1 - r1 * h1 / 2)
42+
// ~ x_reduced * (1/sqrt(x_reduced)) = sqrt(x_reduced)
43+
// - Step 5: Perform rounding test and correction if needed.
44+
// Rounding correction is done by computing the exact rounding errors:
45+
// x_reduced - r2^2.
46+
2047
namespace LIBC_NAMESPACE_DECL {
2148

2249
using FPBits = fputil::FPBits<float128>;
@@ -35,11 +62,11 @@ inline constexpr uint64_t prod_hi<uint64_t>(uint64_t x, uint64_t y) {
3562

3663
// Get high part of unsigned 128x64 bit multiplication.
3764
template <>
38-
inline constexpr UInt128 prod_hi<UInt128, uint64_t>(UInt128 y, uint64_t x) {
39-
uint64_t y_lo = static_cast<uint64_t>(y);
40-
uint64_t y_hi = static_cast<uint64_t>(y >> 64);
41-
UInt128 xyl = static_cast<UInt128>(x) * static_cast<UInt128>(y_lo);
42-
UInt128 xyh = static_cast<UInt128>(x) * static_cast<UInt128>(y_hi);
65+
inline constexpr UInt128 prod_hi<UInt128, uint64_t>(UInt128 x, uint64_t y) {
66+
uint64_t x_lo = static_cast<uint64_t>(x);
67+
uint64_t x_hi = static_cast<uint64_t>(x >> 64);
68+
UInt128 xyl = static_cast<UInt128>(x_lo) * static_cast<UInt128>(y);
69+
UInt128 xyh = static_cast<UInt128>(x_hi) * static_cast<UInt128>(y);
4370
return xyh + (xyl >> 64);
4471
}
4572

@@ -178,11 +205,11 @@ LIBC_INLINE uint64_t rsqrt_approx(uint64_t m) {
178205
// r1 = r0 - r0 * h / 2
179206
// which has error bounded by:
180207
// |r1 - 1/sqrt(x)| < h^2 / 2.
181-
uint64_t r2 = prod_hi<uint64_t>(r, r);
208+
uint64_t r2 = prod_hi(r, r);
182209
// h = r0^2*x - 1.
183-
int64_t h = static_cast<int64_t>(prod_hi<uint64_t>(m, r2) + r2);
210+
int64_t h = static_cast<int64_t>(prod_hi(m, r2) + r2);
184211
// hr = r * h / 2
185-
int64_t hr = prod_hi<int64_t>(h, static_cast<int64_t>(r >> 1));
212+
int64_t hr = prod_hi(h, static_cast<int64_t>(r >> 1));
186213
r -= hr;
187214
// Adjust in the unlucky case x~1;
188215
if (LIBC_UNLIKELY(!r))
@@ -224,8 +251,10 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
224251
fputil::raise_except_if_required(FE_INVALID);
225252
return xbits.quiet_nan().get_val();
226253
}
227-
// x is subnormal or x=+0
228-
if (x == 0)
254+
// Now x is subnormal or x = +0.
255+
256+
// x is +0.
257+
if (x_u == 0)
229258
return x;
230259

231260
// Normalize subnormal inputs.
@@ -253,7 +282,7 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
253282
0xb504f333f9de6484 /* 2^64/sqrt(2) */};
254283

255284
// Approximate 1/sqrt(1 + x_frac)
256-
// Error: |r_1 - 1/sqrt(x)| < 2^-63.
285+
// Error: |r_1 - 1/sqrt(x)| < 2^-62.
257286
uint64_t r1 = rsqrt_approx(static_cast<uint64_t>(x_frac >> 64));
258287
// Adjust for the even/odd exponent.
259288
uint64_t r2 = prod_hi(r1, RSQRT_2[i]);
@@ -279,8 +308,9 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
279308
uint32_t nrst = rm == FE_TONEAREST;
280309
// The result lies within (-2,5) of the true square root, so we now
281310
// test that we can correctly round the result taking into account
282-
// the rounding mode
283-
// check the lowest 14 bits.
311+
// the rounding mode.
312+
// Check the lowest 14 bits (by clearing and sign-extending the top
313+
// 32 - 14 = 18 bits).
284314
int dd = (static_cast<int>(v) << 18) >> 18;
285315

286316
if (LIBC_UNLIKELY(dd < 4 && dd >= -8)) { // can round correctly?
@@ -289,17 +319,16 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
289319
// compare with the initial argument.
290320
UInt128 m = v >> 15;
291321
UInt128 m2 = m * m;
292-
Int128 t0, t1;
293322
// The difference of the squared result and the argument
294-
t0 = static_cast<Int128>(m2 - (x_reduced << 98));
323+
Int128 t0 = static_cast<Int128>(m2 - (x_reduced << 98));
295324
if (t0 == 0) {
296325
// the square root is exact
297326
v = m << 15;
298327
} else {
299328
// Add +-1 ulp to m depending on the sign of the difference. Here
300329
// we do not need to square again since (m+1)^2 = m^2 + 2*m +
301330
// 1 so just need to add shifted m and 1.
302-
t1 = t0;
331+
Int128 t1 = t0;
303332
Int128 sgn = t0 >> 127; // sign of the difference
304333
t1 -= (m << 1) ^ sgn;
305334
t1 += 1 + sgn;
@@ -332,20 +361,20 @@ LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
332361
rnd = frac >> 14; // round to nearest tie to even
333362
} else if (rm == FE_UPWARD) {
334363
rnd = !!frac; // round up
335-
} else if (rm == FE_DOWNWARD) {
336-
rnd = 0; // round down
337364
} else {
338-
rnd = 0; // round to zero
365+
rnd = 0; // round down or round to zero
339366
}
340367

341368
v >>= 15; // position mantissa
342369
v += rnd; // round
343370

344-
// // Set inexact flag only if square root is inexact
345-
// // TODO: We will have to raise FE_INEXACT most of the time, but this
346-
// // operation is very costly, especially in x86-64, since technically, it
347-
// // needs to synchronize both SSE and x87 flags. Need to investigate
348-
// // further to see how we can make this performant.
371+
// Set inexact flag only if square root is inexact
372+
// TODO: We will have to raise FE_INEXACT most of the time, but this
373+
// operation is very costly, especially in x86-64, since technically, it
374+
// needs to synchronize both SSE and x87 flags. Need to investigate
375+
// further to see how we can make this performant.
376+
// https://github.com/llvm/llvm-project/issues/126753
377+
349378
// if(frac) fputil::raise_except_if_required(FE_INEXACT);
350379

351380
v += static_cast<UInt128>(e2) << FPBits::FRACTION_LEN; // place exponent

0 commit comments

Comments
 (0)