Skip to content

[libc][math] Fix signaling nan handling of hypot(f) and improve hypotf performance. #99432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 19, 2024

Conversation

lntue
Copy link
Contributor

@lntue lntue commented Jul 18, 2024

The errors were reported by Paul Zimmermann with the CORE-MATH project's test suites:

zimmerma@tartine:/tmp/core-math$ CORE_MATH_CHECK_STD=true LIBM=$L ./check.sh hypot
Running worst cases check in --rndn mode...
FAIL x=snan y=inf ref=qnan z=inf
Running worst cases check in --rndz mode...
FAIL x=snan y=inf ref=qnan z=inf
Running worst cases check in --rndu mode...
FAIL x=snan y=inf ref=qnan z=inf
Running worst cases check in --rndd mode...
Spurious inexact exception for x=0x1.ffffffffffffep+24 y=0x1p+0 (z=0x1.0000000000001p+25)

@lntue lntue requested review from rupprecht and keith as code owners July 18, 2024 05:26
@lntue
Copy link
Contributor Author

lntue commented Jul 18, 2024

@zimmermann6

@llvmbot llvmbot added libc bazel "Peripheral" support tier build system: utils/bazel labels Jul 18, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2024

@llvm/pr-subscribers-libc

Author: None (lntue)

Changes

The errors were reported by Paul Zimmermann with the CORE-MATH project's test suites:

zimmerma@<!-- -->tartine:/tmp/core-math$ CORE_MATH_CHECK_STD=true LIBM=$L ./check.sh hypot
Running worst cases check in --rndn mode...
FAIL x=snan y=inf ref=qnan z=inf
Running worst cases check in --rndz mode...
FAIL x=snan y=inf ref=qnan z=inf
Running worst cases check in --rndu mode...
FAIL x=snan y=inf ref=qnan z=inf
Running worst cases check in --rndd mode...
Spurious inexact exception for x=0x1.ffffffffffffep+24 y=0x1p+0 (z=0x1.0000000000001p+25)

Full diff: https://github.com/llvm/llvm-project/pull/99432.diff

5 Files Affected:

  • (modified) libc/src/__support/FPUtil/Hypot.h (+30-33)
  • (modified) libc/src/math/generic/CMakeLists.txt (+4-1)
  • (modified) libc/src/math/generic/hypotf.cpp (+60-38)
  • (modified) libc/test/src/math/smoke/HypotTest.h (+9-18)
  • (modified) utils/bazel/llvm-project-overlay/libc/BUILD.bazel (+1)
diff --git a/libc/src/__support/FPUtil/Hypot.h b/libc/src/__support/FPUtil/Hypot.h
index a5a991e75ed01..6aa808446d6d9 100644
--- a/libc/src/__support/FPUtil/Hypot.h
+++ b/libc/src/__support/FPUtil/Hypot.h
@@ -109,45 +109,39 @@ LIBC_INLINE T hypot(T x, T y) {
   using StorageType = typename FPBits<T>::StorageType;
   using DStorageType = typename DoubleLength<StorageType>::Type;
 
-  FPBits_t x_bits(x), y_bits(y);
+  FPBits_t x_abs = FPBits_t(x).abs();
+  FPBits_t y_abs = FPBits_t(y).abs();
 
-  if (x_bits.is_inf() || y_bits.is_inf()) {
-    return FPBits_t::inf().get_val();
-  }
-  if (x_bits.is_nan()) {
-    return x;
-  }
-  if (y_bits.is_nan()) {
+  bool x_abs_larger = x_abs.uintval() >= y_abs.uintval();
+
+  FPBits_t a_bits = x_abs_larger ? x_abs : y_abs;
+  FPBits_t b_bits = x_abs_larger ? y_abs : x_abs;
+
+  if (LIBC_UNLIKELY(a_bits.is_inf_or_nan())) {
+    if (x_abs.is_signaling_nan() || y_abs.is_signaling_nan()) {
+      fputil::raise_except_if_required(FE_INVALID);
+      return FPBits_t::quiet_nan().get_val();
+    }
+    if (x_abs.is_inf() || y_abs.is_inf())
+      return FPBits_t::inf().get_val();
+    if (x_abs.is_nan())
+      return x;
+    // y is nan
     return y;
   }
 
-  uint16_t x_exp = x_bits.get_biased_exponent();
-  uint16_t y_exp = y_bits.get_biased_exponent();
-  uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp);
+  uint16_t a_exp = a_bits.get_biased_exponent();
+  uint16_t b_exp = b_bits.get_biased_exponent();
 
-  if ((exp_diff >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0)) {
-    return abs(x) + abs(y);
-  }
+  if ((a_exp - b_exp >= FPBits_t::FRACTION_LEN + 2) || (x == 0) || (y == 0))
+    return x_abs.get_val() + y_abs.get_val();
 
-  uint16_t a_exp, b_exp, out_exp;
-  StorageType a_mant, b_mant;
+  uint64_t out_exp = a_exp;
+  StorageType a_mant = a_bits.get_mantissa();
+  StorageType b_mant = b_bits.get_mantissa();
   DStorageType a_mant_sq, b_mant_sq;
   bool sticky_bits;
 
-  if (abs(x) >= abs(y)) {
-    a_exp = x_exp;
-    a_mant = x_bits.get_mantissa();
-    b_exp = y_exp;
-    b_mant = y_bits.get_mantissa();
-  } else {
-    a_exp = y_exp;
-    a_mant = y_bits.get_mantissa();
-    b_exp = x_exp;
-    b_mant = x_bits.get_mantissa();
-  }
-
-  out_exp = a_exp;
-
   // Add an extra bit to simplify the final rounding bit computation.
   constexpr StorageType ONE = StorageType(1) << (FPBits_t::FRACTION_LEN + 1);
 
@@ -165,11 +159,10 @@ LIBC_INLINE T hypot(T x, T y) {
     a_exp = 1;
   }
 
-  if (b_exp != 0) {
+  if (b_exp != 0)
     b_mant |= ONE;
-  } else {
+  else
     b_exp = 1;
-  }
 
   a_mant_sq = static_cast<DStorageType>(a_mant) * a_mant;
   b_mant_sq = static_cast<DStorageType>(b_mant) * b_mant;
@@ -260,6 +253,10 @@ LIBC_INLINE T hypot(T x, T y) {
   }
 
   y_new |= static_cast<StorageType>(out_exp) << FPBits_t::FRACTION_LEN;
+
+  if (!(round_bit || sticky_bits || (r != 0)))
+    fputil::clear_except_if_required(FE_INEXACT);
+
   return cpp::bit_cast<T>(y_new);
 }
 
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index 415ca3fbce796..413ece04e9421 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -2809,9 +2809,12 @@ add_entrypoint_object(
   HDRS
     ../hypotf.h
   DEPENDS
-    libc.src.__support.FPUtil.basic_operations
+    libc.src.__support.FPUtil.double_double
+    libc.src.__support.FPUtil.fenv_impl
     libc.src.__support.FPUtil.fp_bits
+    libc.src.__support.FPUtil.multiply_add
     libc.src.__support.FPUtil.sqrt
+    libc.src.__support.macros.optimization
   COMPILE_OPTIONS
     -O3
 )
diff --git a/libc/src/math/generic/hypotf.cpp b/libc/src/math/generic/hypotf.cpp
index 75c55ed846726..478209a14f454 100644
--- a/libc/src/math/generic/hypotf.cpp
+++ b/libc/src/math/generic/hypotf.cpp
@@ -6,11 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 #include "src/math/hypotf.h"
-#include "src/__support/FPUtil/BasicOperations.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
 #include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/FPUtil/double_double.h"
+#include "src/__support/FPUtil/multiply_add.h"
 #include "src/__support/FPUtil/sqrt.h"
 #include "src/__support/common.h"
 #include "src/__support/macros/config.h"
+#include "src/__support/macros/optimization.h"
 
 namespace LIBC_NAMESPACE_DECL {
 
@@ -18,54 +21,73 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
   using DoubleBits = fputil::FPBits<double>;
   using FPBits = fputil::FPBits<float>;
 
-  FPBits x_bits(x), y_bits(y);
+  FPBits x_abs = FPBits(x).abs();
+  FPBits y_abs = FPBits(y).abs();
 
-  uint16_t x_exp = x_bits.get_biased_exponent();
-  uint16_t y_exp = y_bits.get_biased_exponent();
-  uint16_t exp_diff = (x_exp > y_exp) ? (x_exp - y_exp) : (y_exp - x_exp);
+  bool x_abs_larger = x_abs.uintval() >= y_abs.uintval();
 
-  if (exp_diff >= FPBits::FRACTION_LEN + 2) {
-    return fputil::abs(x) + fputil::abs(y);
-  }
+  FPBits a_bits = x_abs_larger ? x_abs : y_abs;
+  FPBits b_bits = x_abs_larger ? y_abs : x_abs;
 
-  double xd = static_cast<double>(x);
-  double yd = static_cast<double>(y);
+  uint32_t a_u = a_bits.uintval();
+  uint32_t b_u = b_bits.uintval();
 
-  // These squares are exact.
-  double x_sq = xd * xd;
-  double y_sq = yd * yd;
+  if (LIBC_UNLIKELY(a_u >= FPBits::EXP_MASK)) {
+    // x or y is inf or nan
+    if (a_bits.is_signaling_nan() || b_bits.is_signaling_nan()) {
+      fputil::raise_except_if_required(FE_INVALID);
+      return FPBits::quiet_nan().get_val();
+    }
+    if (a_bits.is_inf() || b_bits.is_inf())
+      return FPBits::inf().get_val();
+    return a_bits.get_val();
+  }
 
-  // Compute the sum of squares.
-  double sum_sq = x_sq + y_sq;
+  if (LIBC_UNLIKELY(a_u - b_u >=
+                    static_cast<uint32_t>((FPBits::FRACTION_LEN + 2)
+                                          << FPBits::FRACTION_LEN)))
+    return x_abs.get_val() + y_abs.get_val();
 
-  // Compute the rounding error with Fast2Sum algorithm:
-  // x_sq + y_sq = sum_sq - err
-  double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;
+  double ad = static_cast<double>(a_bits.get_val());
+  double bd = static_cast<double>(b_bits.get_val());
+
+  // These squares are exact.
+  double a_sq = ad * ad;
+#ifdef LIBC_TARGET_CPU_HAS_FMA
+  double sum_sq = fputil::multiply_add(bd, bd, a_sq);
+#else
+  double b_sq = bd * bd;
+  double sum_sq = a_sq + b_sq;
+#endif
 
   // Take sqrt in double precision.
   DoubleBits result(fputil::sqrt<double>(sum_sq));
+  uint64_t r_u = result.uintval();
 
-  if (!DoubleBits(sum_sq).is_inf_or_nan()) {
-    // Correct rounding.
-    double r_sq = result.get_val() * result.get_val();
-    double diff = sum_sq - r_sq;
-    constexpr uint64_t MASK = 0x0000'0000'3FFF'FFFFULL;
-    uint64_t lrs = result.uintval() & MASK;
-
-    if (lrs == 0x0000'0000'1000'0000ULL && err < diff) {
-      result.set_uintval(result.uintval() | 1ULL);
-    } else if (lrs == 0x0000'0000'3000'0000ULL && err > diff) {
-      result.set_uintval(result.uintval() - 1ULL);
-    }
-  } else {
-    FPBits bits_x(x), bits_y(y);
-    if (bits_x.is_inf_or_nan() || bits_y.is_inf_or_nan()) {
-      if (bits_x.is_inf() || bits_y.is_inf())
-        return FPBits::inf().get_val();
-      if (bits_x.is_nan())
-        return x;
-      return y;
+  // If any of the sticky bits of the result are non-zero, except the LSB, then
+  // the rounded result is correct.
+  if (LIBC_UNLIKELY(((r_u + 1) & 0x0000'0000'0FFF'FFFE) == 0)) {
+    double r_d = result.get_val();
+
+    // Perform rounding correction.
+#ifdef LIBC_TARGET_CPU_HAS_FMA
+    double sum_sq_lo = fputil::multiply_add(bd, bd, a_sq - sum_sq);
+    double err = sum_sq_lo - fputil::multiply_add(r_d, r_d, -sum_sq);
+#else
+    fputil::DoubleDouble r_sq = fputil::exact_mult(r_d, r_d);
+    double sum_sq_lo = b_sq - (sum_sq - a_sq);
+    double err = (sum_sq - r_sq.hi) + (sum_sq_lo - r_sq.lo);
+#endif
+
+    if (err > 0)
+      r_u |= 1;
+    else if ((err < 0) && (r_u & 1) == 0)
+      r_u -= 1;
+    else if ((r_u & 0x0000'0000'1FFF'FFFF) == 0) {
+      // The rounded result is exact.
+      fputil::clear_except_if_required(FE_INEXACT);
     }
+    return static_cast<float>(DoubleBits(r_u).get_val());
   }
 
   return static_cast<float>(result.get_val());
diff --git a/libc/test/src/math/smoke/HypotTest.h b/libc/test/src/math/smoke/HypotTest.h
index 80e9bb7366dfe..6efc3c8509c7f 100644
--- a/libc/test/src/math/smoke/HypotTest.h
+++ b/libc/test/src/math/smoke/HypotTest.h
@@ -17,22 +17,11 @@
 #include "hdr/math_macros.h"
 
 template <typename T>
-class HypotTestTemplate : public LIBC_NAMESPACE::testing::FEnvSafeTest {
+class HypotTestTemplate : public LIBC_NAMESPACE::testing::Test {
 private:
   using Func = T (*)(T, T);
-  using FPBits = LIBC_NAMESPACE::fputil::FPBits<T>;
-  using StorageType = typename FPBits::StorageType;
 
-  const T nan = FPBits::quiet_nan().get_val();
-  const T inf = FPBits::inf(Sign::POS).get_val();
-  const T neg_inf = FPBits::inf(Sign::NEG).get_val();
-  const T zero = FPBits::zero(Sign::POS).get_val();
-  const T neg_zero = FPBits::zero(Sign::NEG).get_val();
-
-  const T max_normal = FPBits::max_normal().get_val();
-  const T min_normal = FPBits::min_normal().get_val();
-  const T max_subnormal = FPBits::max_subnormal().get_val();
-  const T min_subnormal = FPBits::min_subnormal().get_val();
+  DECLARE_SPECIAL_CONSTANTS(T)
 
 public:
   void test_special_numbers(Func func) {
@@ -40,11 +29,13 @@ class HypotTestTemplate : public LIBC_NAMESPACE::testing::FEnvSafeTest {
     // Pythagorean triples.
     constexpr T PYT[N][3] = {{3, 4, 5}, {5, 12, 13}, {8, 15, 17}, {7, 24, 25}};
 
-    EXPECT_FP_EQ(func(inf, nan), inf);
-    EXPECT_FP_EQ(func(nan, neg_inf), inf);
-    EXPECT_FP_EQ(func(nan, nan), nan);
-    EXPECT_FP_EQ(func(nan, zero), nan);
-    EXPECT_FP_EQ(func(neg_zero, nan), nan);
+    EXPECT_FP_EQ(func(inf, sNaN), aNaN);
+    EXPECT_FP_EQ(func(sNaN, neg_inf), aNaN);
+    EXPECT_FP_EQ(func(inf, aNaN), inf);
+    EXPECT_FP_EQ(func(aNaN, neg_inf), inf);
+    EXPECT_FP_EQ(func(aNaN, aNaN), aNaN);
+    EXPECT_FP_EQ(func(aNaN, zero), aNaN);
+    EXPECT_FP_EQ(func(neg_zero, aNaN), aNaN);
 
     for (int i = 0; i < N; ++i) {
       EXPECT_FP_EQ_ALL_ROUNDING(PYT[i][2], func(PYT[i][0], PYT[i][1]));
diff --git a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
index f0c25e658fd18..63fd3faf38aef 100644
--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
@@ -2020,6 +2020,7 @@ libc_math_function(name = "hypot")
 libc_math_function(
     name = "hypotf",
     additional_deps = [
+        ":__support_fputil_double_double",
         ":__support_fputil_sqrt",
     ],
 )

Comment on lines +32 to +38
EXPECT_FP_EQ(func(inf, sNaN), aNaN);
EXPECT_FP_EQ(func(sNaN, neg_inf), aNaN);
EXPECT_FP_EQ(func(inf, aNaN), inf);
EXPECT_FP_EQ(func(aNaN, neg_inf), inf);
EXPECT_FP_EQ(func(aNaN, aNaN), aNaN);
EXPECT_FP_EQ(func(aNaN, zero), aNaN);
EXPECT_FP_EQ(func(neg_zero, aNaN), aNaN);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EXPECT_FP_EQ takes the expected value as first argument, and the actual value as second argument, so the arguments should be swapped here. Maybe this is one of the things we should someday change across the whole libc in a single commit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely, we should make a sweeping change for this at some point.

Copy link
Member

@overmighty overmighty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, these includes in libc/test/src/math/smoke/HypotTest.h should probably be removed:

  • "src/__support/FPUtil/FPBits.h"
  • "test/UnitTest/FEnvSafeTest.h"
  • "hdr/math_macros.h"

@lntue
Copy link
Contributor Author

lntue commented Jul 19, 2024

Oh, these includes in libc/test/src/math/smoke/HypotTest.h should probably be removed:

  • "src/__support/FPUtil/FPBits.h"
  • "test/UnitTest/FEnvSafeTest.h"
  • "hdr/math_macros.h"

Done.

@lntue lntue merged commit 9da9127 into llvm:main Jul 19, 2024
6 checks passed
@lntue lntue deleted the hypot branch July 19, 2024 14:40
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…f performance. (#99432)

Summary:
The errors were reported by Paul Zimmermann with the CORE-MATH project's
test suites:
```
zimmerma@tartine:/tmp/core-math$ CORE_MATH_CHECK_STD=true LIBM=$L ./check.sh hypot
Running worst cases check in --rndn mode...
FAIL x=snan y=inf ref=qnan z=inf
Running worst cases check in --rndz mode...
FAIL x=snan y=inf ref=qnan z=inf
Running worst cases check in --rndu mode...
FAIL x=snan y=inf ref=qnan z=inf
Running worst cases check in --rndd mode...
Spurious inexact exception for x=0x1.ffffffffffffep+24 y=0x1p+0 (z=0x1.0000000000001p+25)
```

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251275
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel libc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants