Skip to content

[libc][math] Fix signaling nan handling of hypot(f) and improve hypotf performance. #99432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 30 additions & 33 deletions libc/src/__support/FPUtil/Hypot.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,45 +109,39 @@ LIBC_INLINE T hypot(T x, T y) {
using StorageType = typename FPBits<T>::StorageType;
using DStorageType = typename DoubleLength<StorageType>::Type;

FPBits_t x_bits(x), y_bits(y);
FPBits_t x_abs = FPBits_t(x).abs();
FPBits_t y_abs = FPBits_t(y).abs();

if (x_bits.is_inf() || y_bits.is_inf()) {
return FPBits_t::inf().get_val();
}
if (x_bits.is_nan()) {
return x;
}
if (y_bits.is_nan()) {
bool x_abs_larger = x_abs.uintval() >= y_abs.uintval();

FPBits_t a_bits = x_abs_larger ? x_abs : y_abs;
FPBits_t b_bits = x_abs_larger ? y_abs : x_abs;

if (LIBC_UNLIKELY(a_bits.is_inf_or_nan())) {
if (x_abs.is_signaling_nan() || y_abs.is_signaling_nan()) {
fputil::raise_except_if_required(FE_INVALID);
return FPBits_t::quiet_nan().get_val();
}
if (x_abs.is_inf() || y_abs.is_inf())
return FPBits_t::inf().get_val();
if (x_abs.is_nan())
return x;
// y is nan
return y;
}

uint16_t x_exp = x_bits.get_biased_exponent();
uint16_t y_exp = y_bits.get_biased_exponent();
uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp);
uint16_t a_exp = a_bits.get_biased_exponent();
uint16_t b_exp = b_bits.get_biased_exponent();

if ((exp_diff >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0)) {
return abs(x) + abs(y);
}
if ((a_exp - b_exp >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0))
return x_abs.get_val() + y_abs.get_val();

uint16_t a_exp, b_exp, out_exp;
StorageType a_mant, b_mant;
uint64_t out_exp = a_exp;
StorageType a_mant = a_bits.get_mantissa();
StorageType b_mant = b_bits.get_mantissa();
DStorageType a_mant_sq, b_mant_sq;
bool sticky_bits;

if (abs(x) >= abs(y)) {
a_exp = x_exp;
a_mant = x_bits.get_mantissa();
b_exp = y_exp;
b_mant = y_bits.get_mantissa();
} else {
a_exp = y_exp;
a_mant = y_bits.get_mantissa();
b_exp = x_exp;
b_mant = x_bits.get_mantissa();
}

out_exp = a_exp;

// Add an extra bit to simplify the final rounding bit computation.
constexpr StorageType ONE = StorageType(1) << (FPBits_t::FRACTION_LEN + 1);

Expand All @@ -165,11 +159,10 @@ LIBC_INLINE T hypot(T x, T y) {
a_exp = 1;
}

if (b_exp != 0) {
if (b_exp != 0)
b_mant |= ONE;
} else {
else
b_exp = 1;
}

a_mant_sq = static_cast<DStorageType>(a_mant) * a_mant;
b_mant_sq = static_cast<DStorageType>(b_mant) * b_mant;
Expand Down Expand Up @@ -260,6 +253,10 @@ LIBC_INLINE T hypot(T x, T y) {
}

y_new |= static_cast<StorageType>(out_exp) << FPBits_t::FRACTION_LEN;

if (!(round_bit || sticky_bits || (r != 0)))
fputil::clear_except_if_required(FE_INEXACT);

return cpp::bit_cast<T>(y_new);
}

Expand Down
5 changes: 4 additions & 1 deletion libc/src/math/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2809,9 +2809,12 @@ add_entrypoint_object(
HDRS
../hypotf.h
DEPENDS
libc.src.__support.FPUtil.basic_operations
libc.src.__support.FPUtil.double_double
libc.src.__support.FPUtil.fenv_impl
libc.src.__support.FPUtil.fp_bits
libc.src.__support.FPUtil.multiply_add
libc.src.__support.FPUtil.sqrt
libc.src.__support.macros.optimization
COMPILE_OPTIONS
-O3
)
Expand Down
100 changes: 62 additions & 38 deletions libc/src/math/generic/hypotf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,66 +6,90 @@
//
//===----------------------------------------------------------------------===//
#include "src/math/hypotf.h"
#include "src/__support/FPUtil/BasicOperations.h"
#include "src/__support/FPUtil/FEnvImpl.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/double_double.h"
#include "src/__support/FPUtil/multiply_add.h"
#include "src/__support/FPUtil/sqrt.h"
#include "src/__support/common.h"
#include "src/__support/macros/config.h"
#include "src/__support/macros/optimization.h"

namespace LIBC_NAMESPACE_DECL {

LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
using DoubleBits = fputil::FPBits<double>;
using FPBits = fputil::FPBits<float>;

FPBits x_bits(x), y_bits(y);
FPBits x_abs = FPBits(x).abs();
FPBits y_abs = FPBits(y).abs();

uint16_t x_exp = x_bits.get_biased_exponent();
uint16_t y_exp = y_bits.get_biased_exponent();
uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp);
bool x_abs_larger = x_abs.uintval() >= y_abs.uintval();

if (exp_diff >= FPBits::FRACTION_LEN + 2) {
return fputil::abs(x) + fputil::abs(y);
}
FPBits a_bits = x_abs_larger ? x_abs : y_abs;
FPBits b_bits = x_abs_larger ? y_abs : x_abs;

double xd = static_cast<double>(x);
double yd = static_cast<double>(y);
uint32_t a_u = a_bits.uintval();
uint32_t b_u = b_bits.uintval();

// These squares are exact.
double x_sq = xd * xd;
double y_sq = yd * yd;
// Note: replacing `a_u >= FPBits::EXP_MASK` with `a_bits.is_inf_or_nan()`
// generates extra exponent bit masking instructions on x86-64.
if (LIBC_UNLIKELY(a_u >= FPBits::EXP_MASK)) {
// x or y is inf or nan
if (a_bits.is_signaling_nan() || b_bits.is_signaling_nan()) {
fputil::raise_except_if_required(FE_INVALID);
return FPBits::quiet_nan().get_val();
}
if (a_bits.is_inf() || b_bits.is_inf())
return FPBits::inf().get_val();
return a_bits.get_val();
}

// Compute the sum of squares.
double sum_sq = x_sq + y_sq;
if (LIBC_UNLIKELY(a_u - b_u >=
static_cast<uint32_t>((FPBits::FRACTION_LEN + 2)
<< FPBits::FRACTION_LEN)))
return x_abs.get_val() + y_abs.get_val();

// Compute the rounding error with Fast2Sum algorithm:
// x_sq + y_sq = sum_sq - err
double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;
double ad = static_cast<double>(a_bits.get_val());
double bd = static_cast<double>(b_bits.get_val());

// These squares are exact.
double a_sq = ad * ad;
#ifdef LIBC_TARGET_CPU_HAS_FMA
double sum_sq = fputil::multiply_add(bd, bd, a_sq);
#else
double b_sq = bd * bd;
double sum_sq = a_sq + b_sq;
#endif

// Take sqrt in double precision.
DoubleBits result(fputil::sqrt<double>(sum_sq));
uint64_t r_u = result.uintval();

if (!DoubleBits(sum_sq).is_inf_or_nan()) {
// Correct rounding.
double r_sq = result.get_val() * result.get_val();
double diff = sum_sq - r_sq;
constexpr uint64_t MASK = 0x0000'0000'3FFF'FFFFULL;
uint64_t lrs = result.uintval() & MASK;

if (lrs == 0x0000'0000'1000'0000ULL && err < diff) {
result.set_uintval(result.uintval() | 1ULL);
} else if (lrs == 0x0000'0000'3000'0000ULL && err > diff) {
result.set_uintval(result.uintval() - 1ULL);
}
} else {
FPBits bits_x(x), bits_y(y);
if (bits_x.is_inf_or_nan() || bits_y.is_inf_or_nan()) {
if (bits_x.is_inf() || bits_y.is_inf())
return FPBits::inf().get_val();
if (bits_x.is_nan())
return x;
return y;
// If any of the sticky bits of the result are non-zero, except the LSB, then
// the rounded result is correct.
if (LIBC_UNLIKELY(((r_u + 1) & 0x0000'0000'0FFF'FFFE) == 0)) {
double r_d = result.get_val();

// Perform rounding correction.
#ifdef LIBC_TARGET_CPU_HAS_FMA
double sum_sq_lo = fputil::multiply_add(bd, bd, a_sq - sum_sq);
double err = sum_sq_lo - fputil::multiply_add(r_d, r_d, -sum_sq);
#else
fputil::DoubleDouble r_sq = fputil::exact_mult(r_d, r_d);
double sum_sq_lo = b_sq - (sum_sq - a_sq);
double err = (sum_sq - r_sq.hi) + (sum_sq_lo - r_sq.lo);
#endif

if (err > 0) {
r_u |= 1;
} else if ((err < 0) && (r_u & 1) == 0) {
r_u -= 1;
} else if ((r_u & 0x0000'0000'1FFF'FFFF) == 0) {
// The rounded result is exact.
fputil::clear_except_if_required(FE_INEXACT);
}
return static_cast<float>(DoubleBits(r_u).get_val());
}

return static_cast<float>(result.get_val());
Expand Down
4 changes: 4 additions & 0 deletions libc/test/src/math/smoke/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2716,6 +2716,8 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
hypotf_test.cpp
HDRS
HypotTest.h
DEPENDS
libc.src.math.hypotf
libc.src.__support.FPUtil.fp_bits
Expand All @@ -2727,6 +2729,8 @@ add_fp_unittest(
libc-math-smoke-tests
SRCS
hypot_test.cpp
HDRS
HypotTest.h
DEPENDS
libc.src.math.hypot
libc.src.__support.FPUtil.fp_bits
Expand Down
31 changes: 9 additions & 22 deletions libc/test/src/math/smoke/HypotTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,29 @@
#ifndef LLVM_LIBC_TEST_SRC_MATH_HYPOTTEST_H
#define LLVM_LIBC_TEST_SRC_MATH_HYPOTTEST_H

#include "src/__support/FPUtil/FPBits.h"
#include "test/UnitTest/FEnvSafeTest.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"

#include "hdr/math_macros.h"

template <typename T>
class HypotTestTemplate : public LIBC_NAMESPACE::testing::FEnvSafeTest {
class HypotTestTemplate : public LIBC_NAMESPACE::testing::Test {
private:
using Func = T (*)(T, T);
using FPBits = LIBC_NAMESPACE::fputil::FPBits<T>;
using StorageType = typename FPBits::StorageType;

const T nan = FPBits::quiet_nan().get_val();
const T inf = FPBits::inf(Sign::POS).get_val();
const T neg_inf = FPBits::inf(Sign::NEG).get_val();
const T zero = FPBits::zero(Sign::POS).get_val();
const T neg_zero = FPBits::zero(Sign::NEG).get_val();

const T max_normal = FPBits::max_normal().get_val();
const T min_normal = FPBits::min_normal().get_val();
const T max_subnormal = FPBits::max_subnormal().get_val();
const T min_subnormal = FPBits::min_subnormal().get_val();
DECLARE_SPECIAL_CONSTANTS(T)

public:
void test_special_numbers(Func func) {
constexpr int N = 4;
// Pythagorean triples.
constexpr T PYT[N][3] = {{3, 4, 5}, {5, 12, 13}, {8, 15, 17}, {7, 24, 25}};

EXPECT_FP_EQ(func(inf, nan), inf);
EXPECT_FP_EQ(func(nan, neg_inf), inf);
EXPECT_FP_EQ(func(nan, nan), nan);
EXPECT_FP_EQ(func(nan, zero), nan);
EXPECT_FP_EQ(func(neg_zero, nan), nan);
EXPECT_FP_EQ(func(inf, sNaN), aNaN);
EXPECT_FP_EQ(func(sNaN, neg_inf), aNaN);
EXPECT_FP_EQ(func(inf, aNaN), inf);
EXPECT_FP_EQ(func(aNaN, neg_inf), inf);
EXPECT_FP_EQ(func(aNaN, aNaN), aNaN);
EXPECT_FP_EQ(func(aNaN, zero), aNaN);
EXPECT_FP_EQ(func(neg_zero, aNaN), aNaN);
Comment on lines +28 to +34
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EXPECT_FP_EQ takes the expected value as its first argument and the actual value as its second argument, so the arguments should be swapped here. Maybe this is one of the things we should someday change across the whole libc in a single commit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely, we should make a sweeping change for this at some point.


for (int i = 0; i < N; ++i) {
EXPECT_FP_EQ_ALL_ROUNDING(PYT[i][2], func(PYT[i][0], PYT[i][1]));
Expand Down
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/libc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,7 @@ libc_math_function(name = "hypot")
libc_math_function(
name = "hypotf",
additional_deps = [
":__support_fputil_double_double",
":__support_fputil_sqrt",
],
)
Expand Down
Loading