Skip to content

Commit 3b9b235

Browse files
lntueyuxuanchen1997
authored andcommitted
[libc][math] Fix signaling nan handling of hypot(f) and improve hypotf performance. (#99432)
Summary: The errors were reported by Paul Zimmermann with the CORE-MATH project's test suites: ``` zimmerma@tartine:/tmp/core-math$ CORE_MATH_CHECK_STD=true LIBM=$L ./check.sh hypot Running worst cases check in --rndn mode... FAIL x=snan y=inf ref=qnan z=inf Running worst cases check in --rndz mode... FAIL x=snan y=inf ref=qnan z=inf Running worst cases check in --rndu mode... FAIL x=snan y=inf ref=qnan z=inf Running worst cases check in --rndd mode... Spurious inexact exception for x=0x1.ffffffffffffep+24 y=0x1p+0 (z=0x1.0000000000001p+25) ``` Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251275
1 parent 3368287 commit 3b9b235

File tree

6 files changed

+110
-94
lines changed

6 files changed

+110
-94
lines changed

libc/src/__support/FPUtil/Hypot.h

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -109,45 +109,39 @@ LIBC_INLINE T hypot(T x, T y) {
109109
using StorageType = typename FPBits<T>::StorageType;
110110
using DStorageType = typename DoubleLength<StorageType>::Type;
111111

112-
FPBits_t x_bits(x), y_bits(y);
112+
FPBits_t x_abs = FPBits_t(x).abs();
113+
FPBits_t y_abs = FPBits_t(y).abs();
113114

114-
if (x_bits.is_inf() || y_bits.is_inf()) {
115-
return FPBits_t::inf().get_val();
116-
}
117-
if (x_bits.is_nan()) {
118-
return x;
119-
}
120-
if (y_bits.is_nan()) {
115+
bool x_abs_larger = x_abs.uintval() >= y_abs.uintval();
116+
117+
FPBits_t a_bits = x_abs_larger ? x_abs : y_abs;
118+
FPBits_t b_bits = x_abs_larger ? y_abs : x_abs;
119+
120+
if (LIBC_UNLIKELY(a_bits.is_inf_or_nan())) {
121+
if (x_abs.is_signaling_nan() || y_abs.is_signaling_nan()) {
122+
fputil::raise_except_if_required(FE_INVALID);
123+
return FPBits_t::quiet_nan().get_val();
124+
}
125+
if (x_abs.is_inf() || y_abs.is_inf())
126+
return FPBits_t::inf().get_val();
127+
if (x_abs.is_nan())
128+
return x;
129+
// y is nan
121130
return y;
122131
}
123132

124-
uint16_t x_exp = x_bits.get_biased_exponent();
125-
uint16_t y_exp = y_bits.get_biased_exponent();
126-
uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp);
133+
uint16_t a_exp = a_bits.get_biased_exponent();
134+
uint16_t b_exp = b_bits.get_biased_exponent();
127135

128-
if ((exp_diff >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0)) {
129-
return abs(x) + abs(y);
130-
}
136+
if ((a_exp - b_exp >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0))
137+
return x_abs.get_val() + y_abs.get_val();
131138

132-
uint16_t a_exp, b_exp, out_exp;
133-
StorageType a_mant, b_mant;
139+
uint64_t out_exp = a_exp;
140+
StorageType a_mant = a_bits.get_mantissa();
141+
StorageType b_mant = b_bits.get_mantissa();
134142
DStorageType a_mant_sq, b_mant_sq;
135143
bool sticky_bits;
136144

137-
if (abs(x) >= abs(y)) {
138-
a_exp = x_exp;
139-
a_mant = x_bits.get_mantissa();
140-
b_exp = y_exp;
141-
b_mant = y_bits.get_mantissa();
142-
} else {
143-
a_exp = y_exp;
144-
a_mant = y_bits.get_mantissa();
145-
b_exp = x_exp;
146-
b_mant = x_bits.get_mantissa();
147-
}
148-
149-
out_exp = a_exp;
150-
151145
// Add an extra bit to simplify the final rounding bit computation.
152146
constexpr StorageType ONE = StorageType(1) << (FPBits_t::FRACTION_LEN + 1);
153147

@@ -165,11 +159,10 @@ LIBC_INLINE T hypot(T x, T y) {
165159
a_exp = 1;
166160
}
167161

168-
if (b_exp != 0) {
162+
if (b_exp != 0)
169163
b_mant |= ONE;
170-
} else {
164+
else
171165
b_exp = 1;
172-
}
173166

174167
a_mant_sq = static_cast<DStorageType>(a_mant) * a_mant;
175168
b_mant_sq = static_cast<DStorageType>(b_mant) * b_mant;
@@ -260,6 +253,10 @@ LIBC_INLINE T hypot(T x, T y) {
260253
}
261254

262255
y_new |= static_cast<StorageType>(out_exp) << FPBits_t::FRACTION_LEN;
256+
257+
if (!(round_bit || sticky_bits || (r != 0)))
258+
fputil::clear_except_if_required(FE_INEXACT);
259+
263260
return cpp::bit_cast<T>(y_new);
264261
}
265262

libc/src/math/generic/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2830,9 +2830,12 @@ add_entrypoint_object(
28302830
HDRS
28312831
../hypotf.h
28322832
DEPENDS
2833-
libc.src.__support.FPUtil.basic_operations
2833+
libc.src.__support.FPUtil.double_double
2834+
libc.src.__support.FPUtil.fenv_impl
28342835
libc.src.__support.FPUtil.fp_bits
2836+
libc.src.__support.FPUtil.multiply_add
28352837
libc.src.__support.FPUtil.sqrt
2838+
libc.src.__support.macros.optimization
28362839
COMPILE_OPTIONS
28372840
-O3
28382841
)

libc/src/math/generic/hypotf.cpp

Lines changed: 62 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,66 +6,90 @@
66
//
77
//===----------------------------------------------------------------------===//
88
#include "src/math/hypotf.h"
9-
#include "src/__support/FPUtil/BasicOperations.h"
9+
#include "src/__support/FPUtil/FEnvImpl.h"
1010
#include "src/__support/FPUtil/FPBits.h"
11+
#include "src/__support/FPUtil/double_double.h"
12+
#include "src/__support/FPUtil/multiply_add.h"
1113
#include "src/__support/FPUtil/sqrt.h"
1214
#include "src/__support/common.h"
1315
#include "src/__support/macros/config.h"
16+
#include "src/__support/macros/optimization.h"
1417

1518
namespace LIBC_NAMESPACE_DECL {
1619

1720
LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
1821
using DoubleBits = fputil::FPBits<double>;
1922
using FPBits = fputil::FPBits<float>;
2023

21-
FPBits x_bits(x), y_bits(y);
24+
FPBits x_abs = FPBits(x).abs();
25+
FPBits y_abs = FPBits(y).abs();
2226

23-
uint16_t x_exp = x_bits.get_biased_exponent();
24-
uint16_t y_exp = y_bits.get_biased_exponent();
25-
uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp);
27+
bool x_abs_larger = x_abs.uintval() >= y_abs.uintval();
2628

27-
if (exp_diff >= FPBits::FRACTION_LEN + 2) {
28-
return fputil::abs(x) + fputil::abs(y);
29-
}
29+
FPBits a_bits = x_abs_larger ? x_abs : y_abs;
30+
FPBits b_bits = x_abs_larger ? y_abs : x_abs;
3031

31-
double xd = static_cast<double>(x);
32-
double yd = static_cast<double>(y);
32+
uint32_t a_u = a_bits.uintval();
33+
uint32_t b_u = b_bits.uintval();
3334

34-
// These squares are exact.
35-
double x_sq = xd * xd;
36-
double y_sq = yd * yd;
35+
// Note: replacing `a_u >= FPBits::EXP_MASK` with `a_bits.is_inf_or_nan()`
36+
// generates extra exponent bit masking instructions on x86-64.
37+
if (LIBC_UNLIKELY(a_u >= FPBits::EXP_MASK)) {
38+
// x or y is inf or nan
39+
if (a_bits.is_signaling_nan() || b_bits.is_signaling_nan()) {
40+
fputil::raise_except_if_required(FE_INVALID);
41+
return FPBits::quiet_nan().get_val();
42+
}
43+
if (a_bits.is_inf() || b_bits.is_inf())
44+
return FPBits::inf().get_val();
45+
return a_bits.get_val();
46+
}
3747

38-
// Compute the sum of squares.
39-
double sum_sq = x_sq + y_sq;
48+
if (LIBC_UNLIKELY(a_u - b_u >=
49+
static_cast<uint32_t>((FPBits::FRACTION_LEN + 2)
50+
<< FPBits::FRACTION_LEN)))
51+
return x_abs.get_val() + y_abs.get_val();
4052

41-
// Compute the rounding error with Fast2Sum algorithm:
42-
// x_sq + y_sq = sum_sq - err
43-
double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;
53+
double ad = static_cast<double>(a_bits.get_val());
54+
double bd = static_cast<double>(b_bits.get_val());
55+
56+
// These squares are exact.
57+
double a_sq = ad * ad;
58+
#ifdef LIBC_TARGET_CPU_HAS_FMA
59+
double sum_sq = fputil::multiply_add(bd, bd, a_sq);
60+
#else
61+
double b_sq = bd * bd;
62+
double sum_sq = a_sq + b_sq;
63+
#endif
4464

4565
// Take sqrt in double precision.
4666
DoubleBits result(fputil::sqrt<double>(sum_sq));
67+
uint64_t r_u = result.uintval();
4768

48-
if (!DoubleBits(sum_sq).is_inf_or_nan()) {
49-
// Correct rounding.
50-
double r_sq = result.get_val() * result.get_val();
51-
double diff = sum_sq - r_sq;
52-
constexpr uint64_t MASK = 0x0000'0000'3FFF'FFFFULL;
53-
uint64_t lrs = result.uintval() & MASK;
54-
55-
if (lrs == 0x0000'0000'1000'0000ULL && err < diff) {
56-
result.set_uintval(result.uintval() | 1ULL);
57-
} else if (lrs == 0x0000'0000'3000'0000ULL && err > diff) {
58-
result.set_uintval(result.uintval() - 1ULL);
59-
}
60-
} else {
61-
FPBits bits_x(x), bits_y(y);
62-
if (bits_x.is_inf_or_nan() || bits_y.is_inf_or_nan()) {
63-
if (bits_x.is_inf() || bits_y.is_inf())
64-
return FPBits::inf().get_val();
65-
if (bits_x.is_nan())
66-
return x;
67-
return y;
69+
// If any of the sticky bits of the result are non-zero, except the LSB, then
70+
// the rounded result is correct.
71+
if (LIBC_UNLIKELY(((r_u + 1) & 0x0000'0000'0FFF'FFFE) == 0)) {
72+
double r_d = result.get_val();
73+
74+
// Perform rounding correction.
75+
#ifdef LIBC_TARGET_CPU_HAS_FMA
76+
double sum_sq_lo = fputil::multiply_add(bd, bd, a_sq - sum_sq);
77+
double err = sum_sq_lo - fputil::multiply_add(r_d, r_d, -sum_sq);
78+
#else
79+
fputil::DoubleDouble r_sq = fputil::exact_mult(r_d, r_d);
80+
double sum_sq_lo = b_sq - (sum_sq - a_sq);
81+
double err = (sum_sq - r_sq.hi) + (sum_sq_lo - r_sq.lo);
82+
#endif
83+
84+
if (err > 0) {
85+
r_u |= 1;
86+
} else if ((err < 0) && (r_u & 1) == 0) {
87+
r_u -= 1;
88+
} else if ((r_u & 0x0000'0000'1FFF'FFFF) == 0) {
89+
// The rounded result is exact.
90+
fputil::clear_except_if_required(FE_INEXACT);
6891
}
92+
return static_cast<float>(DoubleBits(r_u).get_val());
6993
}
7094

7195
return static_cast<float>(result.get_val());

libc/test/src/math/smoke/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2733,6 +2733,8 @@ add_fp_unittest(
27332733
libc-math-smoke-tests
27342734
SRCS
27352735
hypotf_test.cpp
2736+
HDRS
2737+
HypotTest.h
27362738
DEPENDS
27372739
libc.src.math.hypotf
27382740
libc.src.__support.FPUtil.fp_bits
@@ -2744,6 +2746,8 @@ add_fp_unittest(
27442746
libc-math-smoke-tests
27452747
SRCS
27462748
hypot_test.cpp
2749+
HDRS
2750+
HypotTest.h
27472751
DEPENDS
27482752
libc.src.math.hypot
27492753
libc.src.__support.FPUtil.fp_bits

libc/test/src/math/smoke/HypotTest.h

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,42 +9,29 @@
99
#ifndef LLVM_LIBC_TEST_SRC_MATH_HYPOTTEST_H
1010
#define LLVM_LIBC_TEST_SRC_MATH_HYPOTTEST_H
1111

12-
#include "src/__support/FPUtil/FPBits.h"
13-
#include "test/UnitTest/FEnvSafeTest.h"
1412
#include "test/UnitTest/FPMatcher.h"
1513
#include "test/UnitTest/Test.h"
1614

17-
#include "hdr/math_macros.h"
18-
1915
template <typename T>
20-
class HypotTestTemplate : public LIBC_NAMESPACE::testing::FEnvSafeTest {
16+
class HypotTestTemplate : public LIBC_NAMESPACE::testing::Test {
2117
private:
2218
using Func = T (*)(T, T);
23-
using FPBits = LIBC_NAMESPACE::fputil::FPBits<T>;
24-
using StorageType = typename FPBits::StorageType;
25-
26-
const T nan = FPBits::quiet_nan().get_val();
27-
const T inf = FPBits::inf(Sign::POS).get_val();
28-
const T neg_inf = FPBits::inf(Sign::NEG).get_val();
29-
const T zero = FPBits::zero(Sign::POS).get_val();
30-
const T neg_zero = FPBits::zero(Sign::NEG).get_val();
3119

32-
const T max_normal = FPBits::max_normal().get_val();
33-
const T min_normal = FPBits::min_normal().get_val();
34-
const T max_subnormal = FPBits::max_subnormal().get_val();
35-
const T min_subnormal = FPBits::min_subnormal().get_val();
20+
DECLARE_SPECIAL_CONSTANTS(T)
3621

3722
public:
3823
void test_special_numbers(Func func) {
3924
constexpr int N = 4;
4025
// Pythagorean triples.
4126
constexpr T PYT[N][3] = {{3, 4, 5}, {5, 12, 13}, {8, 15, 17}, {7, 24, 25}};
4227

43-
EXPECT_FP_EQ(func(inf, nan), inf);
44-
EXPECT_FP_EQ(func(nan, neg_inf), inf);
45-
EXPECT_FP_EQ(func(nan, nan), nan);
46-
EXPECT_FP_EQ(func(nan, zero), nan);
47-
EXPECT_FP_EQ(func(neg_zero, nan), nan);
28+
EXPECT_FP_EQ(func(inf, sNaN), aNaN);
29+
EXPECT_FP_EQ(func(sNaN, neg_inf), aNaN);
30+
EXPECT_FP_EQ(func(inf, aNaN), inf);
31+
EXPECT_FP_EQ(func(aNaN, neg_inf), inf);
32+
EXPECT_FP_EQ(func(aNaN, aNaN), aNaN);
33+
EXPECT_FP_EQ(func(aNaN, zero), aNaN);
34+
EXPECT_FP_EQ(func(neg_zero, aNaN), aNaN);
4835

4936
for (int i = 0; i < N; ++i) {
5037
EXPECT_FP_EQ_ALL_ROUNDING(PYT[i][2], func(PYT[i][0], PYT[i][1]));

utils/bazel/llvm-project-overlay/libc/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2020,6 +2020,7 @@ libc_math_function(name = "hypot")
20202020
libc_math_function(
20212021
name = "hypotf",
20222022
additional_deps = [
2023+
":__support_fputil_double_double",
20232024
":__support_fputil_sqrt",
20242025
],
20252026
)

0 commit comments

Comments
 (0)