Skip to content

Commit f55963d

Browse files
committed
[libc] Add implementation for hypotf
Truncating the sum of squares, and then use shift-and-add algorithm to compute its square root. Required MPFR testing infra is updated in https://reviews.llvm.org/D87514 Differential Revision: https://reviews.llvm.org/D87516
1 parent 2ffaa9a commit f55963d

File tree

9 files changed

+335
-0
lines changed

9 files changed

+335
-0
lines changed

libc/config/linux/aarch64/entrypoints.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ set(TARGET_LIBM_ENTRYPOINTS
6464
libc.src.math.frexp
6565
libc.src.math.frexpf
6666
libc.src.math.frexpl
67+
libc.src.math.hypotf
6768
libc.src.math.logb
6869
libc.src.math.logbf
6970
libc.src.math.logbl

libc/config/linux/api.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def MathAPI : PublicAPI<"math.h"> {
191191
"frexp",
192192
"frexpf",
193193
"frexpl",
194+
"hypotf",
194195
"logb",
195196
"logbf",
196197
"logbl",

libc/config/linux/x86_64/entrypoints.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ set(TARGET_LIBM_ENTRYPOINTS
9797
libc.src.math.frexp
9898
libc.src.math.frexpf
9999
libc.src.math.frexpl
100+
libc.src.math.hypotf
100101
libc.src.math.logb
101102
libc.src.math.logbf
102103
libc.src.math.logbl

libc/spec/stdc.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ def StdC : StandardSpec<"stdc"> {
296296
FunctionSpec<"frexpf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<IntPtr>]>,
297297
FunctionSpec<"frexpl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>, ArgSpec<IntPtr>]>,
298298

299+
FunctionSpec<"hypotf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<FloatType>]>,
300+
299301
FunctionSpec<"logb", RetValSpec<DoubleType>, [ArgSpec<DoubleType>]>,
300302
FunctionSpec<"logbf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,
301303
FunctionSpec<"logbl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>]>,

libc/src/math/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,3 +593,15 @@ add_entrypoint_object(
593593
COMPILE_OPTIONS
594594
-O2
595595
)
596+
597+
add_entrypoint_object(
598+
hypotf
599+
SRCS
600+
hypotf.cpp
601+
HDRS
602+
hypotf.h
603+
DEPENDS
604+
libc.utils.FPUtil.fputil
605+
COMPILE_OPTIONS
606+
-O2
607+
)

libc/src/math/hypotf.cpp

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
//===-- Implementation of hypotf function ---------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "src/__support/common.h"
9+
#include "utils/FPUtil/BasicOperations.h"
10+
#include "utils/FPUtil/FPBits.h"
11+
12+
namespace __llvm_libc {
13+
14+
using namespace fputil;
15+
16+
uint32_t findLeadingOne(uint32_t mant, int &shift_length) {
17+
shift_length = 0;
18+
constexpr int nsteps = 5;
19+
constexpr uint32_t bounds[nsteps] = {1 << 16, 1 << 8, 1 << 4, 1 << 2, 1 << 1};
20+
constexpr int shifts[nsteps] = {16, 8, 4, 2, 1};
21+
for (int i = 0; i < nsteps; ++i) {
22+
if (mant >= bounds[i]) {
23+
shift_length += shifts[i];
24+
mant >>= shifts[i];
25+
}
26+
}
27+
return 1U << shift_length;
28+
}
29+
30+
// Correctly rounded IEEE 754 HYPOT(x, y) with round to nearest, ties to even.
31+
//
32+
// Algorithm:
33+
// - Let a = max(|x|, |y|), b = min(|x|, |y|), then we have that:
34+
// a <= sqrt(a^2 + b^2) <= min(a + b, a*sqrt(2))
35+
// 1. So if b < eps(a)/2, then HYPOT(x, y) = a.
36+
//
37+
// - Moreover, the exponent part of HYPOT(x, y) is either the same or 1 more
38+
// than the exponent part of a.
39+
//
40+
// 2. For the remaining cases, we will use the digit-by-digit (shift-and-add)
41+
// algorithm to compute SQRT(Z):
42+
//
43+
// - For Y = y0.y1...yn... = SQRT(Z),
44+
// let Y(n) = y0.y1...yn be the first n fractional digits of Y.
45+
//
46+
// - The nth scaled residual R(n) is defined to be:
47+
// R(n) = 2^n * (Z - Y(n)^2)
48+
//
49+
// - Since Y(n) = Y(n - 1) + yn * 2^(-n), the scaled residual
50+
// satisfies the following recurrence formula:
51+
// R(n) = 2*R(n - 1) - yn*(2*Y(n - 1) + 2^(-n)),
52+
// with the initial conditions:
53+
// Y(0) = y0, and R(0) = Z - y0.
54+
//
55+
// - So the nth fractional digit of Y = SQRT(Z) can be decided by:
56+
// yn = 1 if 2*R(n - 1) >= 2*Y(n - 1) + 2^(-n),
57+
// 0 otherwise.
58+
//
59+
// 3. Precision analysis:
60+
//
61+
// - Notice that in the decision function:
62+
// 2*R(n - 1) >= 2*Y(n - 1) + 2^(-n),
63+
// the right hand side only uses up to the 2^(-n)-bit, and both sides are
64+
// non-negative, so R(n - 1) can be truncated at the 2^(-(n + 1))-bit, so
65+
// that 2*R(n - 1) is corrected up to the 2^(-n)-bit.
66+
//
67+
// - Thus, in order to round SQRT(a^2 + b^2) correctly up to n-fractional
68+
// bits, we need to perform the summation (a^2 + b^2) correctly up to (2n +
69+
// 2)-fractional bits, and the remaining bits are sticky bits (i.e. we only
70+
// care if they are 0 or > 0), and the comparisons, additions/subtractions
71+
// can be done in n-fractional bits precision.
72+
//
73+
// - For single precision (float), we can use uint64_t to store the sum a^2 +
74+
// b^2 exact up to (2n + 2)-fractional bits.
75+
//
76+
// - Then we can feed this sum into the digit-by-digit algorithm for SQRT(Z)
77+
// described above.
78+
//
79+
//
80+
// Special cases:
81+
// - HYPOT(x, y) is +Inf if x or y is +Inf or -Inf; else
82+
// - HYPOT(x, y) is NaN if x or y is NaN.
83+
//
84+
float LLVM_LIBC_ENTRYPOINT(hypotf)(float x, float y) {
85+
FPBits<float> x_bits(x), y_bits(y);
86+
87+
if (x_bits.isInf() || y_bits.isInf()) {
88+
return FPBits<float>::inf();
89+
}
90+
if (x_bits.isNaN()) {
91+
return x;
92+
}
93+
if (y_bits.isNaN()) {
94+
return y;
95+
}
96+
97+
uint16_t a_exp, b_exp, out_exp;
98+
uint32_t a_mant, b_mant;
99+
uint64_t a_mant_sq, b_mant_sq;
100+
bool sticky_bits;
101+
102+
if ((x_bits.exponent >= y_bits.exponent + MantissaWidth<float>::value + 2) ||
103+
(y == 0)) {
104+
return abs(x);
105+
} else if ((y_bits.exponent >=
106+
x_bits.exponent + MantissaWidth<float>::value + 2) ||
107+
(x == 0)) {
108+
y_bits.sign = 0;
109+
return abs(y);
110+
}
111+
112+
if (x >= y) {
113+
a_exp = x_bits.exponent;
114+
a_mant = x_bits.mantissa;
115+
b_exp = y_bits.exponent;
116+
b_mant = y_bits.mantissa;
117+
} else {
118+
a_exp = y_bits.exponent;
119+
a_mant = y_bits.mantissa;
120+
b_exp = x_bits.exponent;
121+
b_mant = x_bits.mantissa;
122+
}
123+
124+
out_exp = a_exp;
125+
126+
// Add an extra bit to simplify the final rounding bit computation.
127+
constexpr uint32_t one = 1U << (MantissaWidth<float>::value + 1);
128+
129+
a_mant <<= 1;
130+
b_mant <<= 1;
131+
132+
uint32_t leading_one;
133+
int y_mant_width;
134+
if (a_exp != 0) {
135+
leading_one = one;
136+
a_mant |= one;
137+
y_mant_width = MantissaWidth<float>::value + 1;
138+
} else {
139+
leading_one = findLeadingOne(a_mant, y_mant_width);
140+
}
141+
142+
if (b_exp != 0) {
143+
b_mant |= one;
144+
}
145+
146+
a_mant_sq = static_cast<uint64_t>(a_mant) * a_mant;
147+
b_mant_sq = static_cast<uint64_t>(b_mant) * b_mant;
148+
149+
// At this point, a_exp >= b_exp > a_exp - 25, so in order to line up aSqMant
150+
// and bSqMant, we need to shift bSqMant to the right by (a_exp - b_exp) bits.
151+
// But before that, remember to store the losing bits to sticky.
152+
// The shift length is for a^2 and b^2, so it's double of the exponent
153+
// difference between a and b.
154+
uint16_t shift_length = 2 * (a_exp - b_exp);
155+
sticky_bits = ((b_mant_sq & ((1ULL << shift_length) - 1)) != 0);
156+
b_mant_sq >>= shift_length;
157+
158+
uint64_t sum = a_mant_sq + b_mant_sq;
159+
if (sum >= (1ULL << (2 * y_mant_width + 2))) {
160+
// a^2 + b^2 >= 4* leading_one^2, so we will need an extra bit to the left.
161+
if (leading_one == one) {
162+
// For normal result, we discard the last 2 bits of the sum and increase
163+
// the exponent.
164+
sticky_bits = sticky_bits || ((sum & 0x3U) != 0);
165+
sum >>= 2;
166+
++out_exp;
167+
if (out_exp >= FPBits<float>::maxExponent) {
168+
return FPBits<float>::inf();
169+
}
170+
} else {
171+
// For denormal result, we simply move the leading bit of the result to
172+
// the left by 1.
173+
leading_one <<= 1;
174+
++y_mant_width;
175+
}
176+
}
177+
178+
uint32_t Y = leading_one;
179+
uint32_t R = static_cast<uint32_t>(sum >> y_mant_width) - leading_one;
180+
uint32_t tailBits = static_cast<uint32_t>(sum) & (leading_one - 1);
181+
182+
for (uint32_t current_bit = leading_one >> 1; current_bit;
183+
current_bit >>= 1) {
184+
R = (R << 1) + ((tailBits & current_bit) ? 1 : 0);
185+
uint32_t tmp = (Y << 1) + current_bit; // 2*y(n - 1) + 2^(-n)
186+
if (R >= tmp) {
187+
R -= tmp;
188+
Y += current_bit;
189+
}
190+
}
191+
192+
bool round_bit = Y & 1U;
193+
bool lsb = Y & 2U;
194+
195+
if (Y >= one) {
196+
Y -= one;
197+
198+
if (out_exp == 0) {
199+
out_exp = 1;
200+
}
201+
}
202+
203+
Y >>= 1;
204+
205+
// Round to the nearest, tie to even.
206+
if (round_bit && (lsb || sticky_bits || (R != 0))) {
207+
++Y;
208+
}
209+
210+
if (Y >= (one >> 1)) {
211+
Y -= one >> 1;
212+
++out_exp;
213+
if (out_exp >= FPBits<float>::maxExponent) {
214+
return FPBits<float>::inf();
215+
}
216+
}
217+
218+
Y |= static_cast<uint32_t>(out_exp) << MantissaWidth<float>::value;
219+
return *reinterpret_cast<float *>(&Y);
220+
}
221+
222+
} // namespace __llvm_libc

libc/src/math/hypotf.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===-- Implementation header for hypotf ------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_LIBC_SRC_MATH_HYPOTF_H
10+
#define LLVM_LIBC_SRC_MATH_HYPOTF_H
11+
12+
namespace __llvm_libc {
13+
14+
float hypotf(float x, float y);
15+
16+
} // namespace __llvm_libc
17+
18+
#endif // LLVM_LIBC_SRC_MATH_HYPOTF_H

libc/test/src/math/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,16 @@ add_fp_unittest(
591591
libc.src.math.remquol
592592
libc.utils.FPUtil.fputil
593593
)
594+
595+
add_fp_unittest(
596+
hypotf_test
597+
NEED_MPFR
598+
SUITE
599+
libc_math_unittests
600+
SRCS
601+
hypotf_test.cpp
602+
DEPENDS
603+
libc.include.math
604+
libc.src.math.hypotf
605+
libc.utils.FPUtil.fputil
606+
)

libc/test/src/math/hypotf_test.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===-- Unittests for hypotf ----------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "include/math.h"
10+
#include "src/math/hypotf.h"
11+
#include "utils/FPUtil/FPBits.h"
12+
#include "utils/FPUtil/TestHelpers.h"
13+
#include "utils/MPFRWrapper/MPFRUtils.h"
14+
#include "utils/UnitTest/Test.h"
15+
16+
using FPBits = __llvm_libc::fputil::FPBits<float>;
17+
using UIntType = FPBits::UIntType;
18+
19+
namespace mpfr = __llvm_libc::testing::mpfr;
20+
21+
static const float zero = FPBits::zero();
22+
static const float negZero = FPBits::negZero();
23+
static const float nan = FPBits::buildNaN(1);
24+
static const float inf = FPBits::inf();
25+
static const float negInf = FPBits::negInf();
26+
27+
TEST(HypotfTest, SpecialNumbers) {
28+
EXPECT_FP_EQ(__llvm_libc::hypotf(inf, nan), inf);
29+
EXPECT_FP_EQ(__llvm_libc::hypotf(nan, negInf), inf);
30+
EXPECT_FP_EQ(__llvm_libc::hypotf(zero, inf), inf);
31+
EXPECT_FP_EQ(__llvm_libc::hypotf(negInf, negZero), inf);
32+
33+
EXPECT_FP_EQ(__llvm_libc::hypotf(nan, nan), nan);
34+
EXPECT_FP_EQ(__llvm_libc::hypotf(nan, zero), nan);
35+
EXPECT_FP_EQ(__llvm_libc::hypotf(negZero, nan), nan);
36+
37+
EXPECT_FP_EQ(__llvm_libc::hypotf(negZero, zero), zero);
38+
}
39+
40+
TEST(HypotfTest, SubnormalRange) {
41+
constexpr UIntType count = 1000001;
42+
constexpr UIntType step =
43+
(FPBits::maxSubnormal - FPBits::minSubnormal) / count;
44+
for (UIntType v = FPBits::minSubnormal, w = FPBits::maxSubnormal;
45+
v <= FPBits::maxSubnormal && w >= FPBits::minSubnormal;
46+
v += step, w -= step) {
47+
float x = FPBits(v), y = FPBits(w);
48+
float result = __llvm_libc::hypotf(x, y);
49+
mpfr::BinaryInput<float> input{x, y};
50+
ASSERT_MPFR_MATCH(mpfr::Operation::Hypot, input, result, 0.5);
51+
}
52+
}
53+
54+
TEST(HypotfTest, NormalRange) {
55+
constexpr UIntType count = 1000001;
56+
constexpr UIntType step = (FPBits::maxNormal - FPBits::minNormal) / count;
57+
for (UIntType v = FPBits::minNormal, w = FPBits::maxNormal;
58+
v <= FPBits::maxNormal && w >= FPBits::minNormal; v += step, w -= step) {
59+
float x = FPBits(v), y = FPBits(w);
60+
float result = __llvm_libc::hypotf(x, y);
61+
;
62+
mpfr::BinaryInput<float> input{x, y};
63+
ASSERT_MPFR_MATCH(mpfr::Operation::Hypot, input, result, 0.5);
64+
}
65+
}

0 commit comments

Comments
 (0)