Skip to content

[libc][math][c23] Add f16sqrtf C23 math function #95251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libc/config/linux/aarch64/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
libc.src.math.canonicalizef16
libc.src.math.ceilf16
libc.src.math.copysignf16
libc.src.math.f16sqrtf
libc.src.math.fabsf16
libc.src.math.fdimf16
libc.src.math.floorf16
Expand Down
1 change: 1 addition & 0 deletions libc/config/linux/x86_64/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
libc.src.math.canonicalizef16
libc.src.math.ceilf16
libc.src.math.copysignf16
libc.src.math.f16sqrtf
libc.src.math.fabsf16
libc.src.math.fdimf16
libc.src.math.floorf16
Expand Down
2 changes: 2 additions & 0 deletions libc/docs/math/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ Higher Math Functions
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| fma | |check| | |check| | | | | 7.12.13.1 | F.10.10.1 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| f16sqrt | |check| | | | N/A | | 7.12.14.6 | F.10.11 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| fsqrt | N/A | | | N/A | | 7.12.14.6 | F.10.11 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| hypot | |check| | |check| | | | | 7.12.7.4 | F.10.4.4 |
Expand Down
2 changes: 2 additions & 0 deletions libc/spec/stdc.td
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,8 @@ def StdC : StandardSpec<"stdc"> {
GuardedFunctionSpec<"totalorderf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,

GuardedFunctionSpec<"totalordermagf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,

GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
]
>;

Expand Down
1 change: 1 addition & 0 deletions libc/src/__support/FPUtil/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_header_library(
sqrt.h
sqrt_80_bit_long_double.h
DEPENDS
libc.hdr.fenv_macros
libc.src.__support.common
libc.src.__support.CPP.bit
libc.src.__support.CPP.type_traits
Expand Down
128 changes: 99 additions & 29 deletions libc/src/__support/FPUtil/generic/sqrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "src/__support/common.h"
#include "src/__support/uint128.h"

#include "hdr/fenv_macros.h"

namespace LIBC_NAMESPACE {
namespace fputil {

Expand Down Expand Up @@ -64,40 +66,50 @@ LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {

// Correctly rounded IEEE 754 SQRT for all rounding modes.
// Shift-and-add algorithm.
template <typename T>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {

if constexpr (internal::SpecialLongDouble<T>::VALUE) {
template <typename OutType, typename InType>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
cpp::is_floating_point_v<InType> &&
sizeof(OutType) <= sizeof(InType),
OutType>
sqrt(InType x) {
if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
internal::SpecialLongDouble<InType>::VALUE) {
// Special 80-bit long double.
return x86::sqrt(x);
} else {
// IEEE floating points formats.
using FPBits_t = typename fputil::FPBits<T>;
using StorageType = typename FPBits_t::StorageType;
constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN;
constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val();

FPBits_t bits(x);

if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
using OutFPBits = typename fputil::FPBits<OutType>;
using OutStorageType = typename OutFPBits::StorageType;
using InFPBits = typename fputil::FPBits<InType>;
using InStorageType = typename InFPBits::StorageType;
constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
constexpr int EXTRA_FRACTION_LEN =
InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
constexpr InStorageType EXTRA_FRACTION_MASK =
(InStorageType(1) << EXTRA_FRACTION_LEN) - 1;

InFPBits bits(x);

if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
// sqrt(+Inf) = +Inf
// sqrt(+0) = +0
// sqrt(-0) = -0
// sqrt(NaN) = NaN
// sqrt(-NaN) = -NaN
return x;
return static_cast<OutType>(x);
} else if (bits.is_neg()) {
// sqrt(-Inf) = NaN
// sqrt(-x) = NaN
return FLT_NAN;
} else {
int x_exp = bits.get_exponent();
StorageType x_mant = bits.get_mantissa();
InStorageType x_mant = bits.get_mantissa();

// Step 1a: Normalize denormal input and append hidden bit to the mantissa
if (bits.is_subnormal()) {
++x_exp; // let x_exp be the correct exponent of ONE bit.
internal::normalize<T>(x_exp, x_mant);
internal::normalize<InType>(x_exp, x_mant);
} else {
x_mant |= ONE;
}
Expand All @@ -120,47 +132,105 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
// So the nth digit y_n of the mantissa of sqrt(x) can be found by:
// y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
// 0 otherwise.
StorageType y = ONE;
StorageType r = x_mant - ONE;
InStorageType y = ONE;
InStorageType r = x_mant - ONE;

for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
for (InStorageType current_bit = ONE >> 1; current_bit;
current_bit >>= 1) {
r <<= 1;
StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
InStorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
if (r >= tmp) {
r -= tmp;
y += current_bit;
}
}

// We compute one more iteration in order to round correctly.
bool lsb = static_cast<bool>(y & 1); // Least significant bit
bool rb = false; // Round bit
bool lsb = (y & (InStorageType(1) << EXTRA_FRACTION_LEN)) !=
0; // Least significant bit
bool rb = false; // Round bit
r <<= 2;
StorageType tmp = (y << 2) + 1;
InStorageType tmp = (y << 2) + 1;
if (r >= tmp) {
r -= tmp;
rb = true;
}

bool sticky = false;

if constexpr (EXTRA_FRACTION_LEN > 0) {
sticky = rb || (y & EXTRA_FRACTION_MASK) != 0;
rb = (y & (InStorageType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
}

// Remove hidden bit and append the exponent field.
x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS);
x_exp = ((x_exp >> 1) + OutFPBits::EXP_BIAS);

OutStorageType y_out = static_cast<OutStorageType>(
((y - ONE) >> EXTRA_FRACTION_LEN) |
(static_cast<OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));

if constexpr (EXTRA_FRACTION_LEN > 0) {
if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
switch (quick_get_round()) {
case FE_TONEAREST:
case FE_UPWARD:
return OutFPBits::inf().get_val();
default:
return OutFPBits::max_normal().get_val();
}
}

if (x_exp <
-OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) {
switch (quick_get_round()) {
case FE_UPWARD:
return OutFPBits::min_subnormal().get_val();
default:
return OutType(0.0);
}
}

y = (y - ONE) |
(static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
if (x_exp <= 0) {
int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1;
InStorageType underflow_extra_fraction_mask =
(InStorageType(1) << underflow_extra_fraction_len) - 1;

rb = (y & (InStorageType(1) << (underflow_extra_fraction_len - 1))) !=
0;
OutStorageType subnormal_mant =
static_cast<OutStorageType>(y >> underflow_extra_fraction_len);
lsb = (subnormal_mant & 1) != 0;
sticky = sticky || (y & underflow_extra_fraction_mask) != 0;

switch (quick_get_round()) {
case FE_TONEAREST:
if (rb && (lsb || sticky))
++subnormal_mant;
break;
case FE_UPWARD:
if (rb || sticky)
++subnormal_mant;
break;
}

return cpp::bit_cast<OutType>(subnormal_mant);
}
}

switch (quick_get_round()) {
case FE_TONEAREST:
// Round to nearest, ties to even
if (rb && (lsb || (r != 0)))
++y;
++y_out;
break;
case FE_UPWARD:
if (rb || (r != 0))
++y;
if (rb || (r != 0) || sticky)
++y_out;
break;
}

return cpp::bit_cast<T>(y);
return cpp::bit_cast<OutType>(y_out);
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions libc/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ add_math_entrypoint_object(exp10f)
add_math_entrypoint_object(expm1)
add_math_entrypoint_object(expm1f)

add_math_entrypoint_object(f16sqrtf)

add_math_entrypoint_object(fabs)
add_math_entrypoint_object(fabsf)
add_math_entrypoint_object(fabsl)
Expand Down
20 changes: 20 additions & 0 deletions libc/src/math/f16sqrtf.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===-- Implementation header for f16sqrtf ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_MATH_F16SQRTF_H
#define LLVM_LIBC_SRC_MATH_F16SQRTF_H

#include "src/__support/macros/properties/types.h"

namespace LIBC_NAMESPACE {

// Declaration of the f16sqrtf entrypoint: takes a single float argument and
// returns a float16 value.
// NOTE(review): float16 is presumably provided by the
// src/__support/macros/properties/types.h include above -- confirm.
float16 f16sqrtf(float x);

} // namespace LIBC_NAMESPACE

#endif // LLVM_LIBC_SRC_MATH_F16SQRTF_H
13 changes: 13 additions & 0 deletions libc/src/math/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3601,3 +3601,16 @@ add_entrypoint_object(
COMPILE_OPTIONS
-O3
)

# Entrypoint object for f16sqrtf: built from f16sqrtf.cpp, exposes the
# ../f16sqrtf.h header, depends on the types properties target and the FPUtil
# sqrt target, and is compiled with -O3.
add_entrypoint_object(
f16sqrtf
SRCS
f16sqrtf.cpp
HDRS
../f16sqrtf.h
DEPENDS
libc.src.__support.macros.properties.types
libc.src.__support.FPUtil.sqrt
COMPILE_OPTIONS
-O3
)
2 changes: 1 addition & 1 deletion libc/src/math/generic/acosf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ LLVM_LIBC_FUNCTION(float, acosf, (float x)) {
xbits.set_sign(Sign::POS);
double xd = static_cast<double>(xbits.get_val());
double u = fputil::multiply_add(-0.5, xd, 0.5);
double cv = 2 * fputil::sqrt(u);
double cv = 2 * fputil::sqrt<double>(u);

double r3 = asin_eval(u);
double r = fputil::multiply_add(cv * u, r3, cv);
Expand Down
4 changes: 2 additions & 2 deletions libc/src/math/generic/acoshf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ LLVM_LIBC_FUNCTION(float, acoshf, (float x)) {

double x_d = static_cast<double>(x);
// acosh(x) = log(x + sqrt(x^2 - 1))
return static_cast<float>(
log_eval(x_d + fputil::sqrt(fputil::multiply_add(x_d, x_d, -1.0))));
return static_cast<float>(log_eval(
x_d + fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, -1.0))));
}

} // namespace LIBC_NAMESPACE
2 changes: 1 addition & 1 deletion libc/src/math/generic/asinf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ LLVM_LIBC_FUNCTION(float, asinf, (float x)) {
double sign = SIGN[x_sign];
double xd = static_cast<double>(xbits.get_val());
double u = fputil::multiply_add(-0.5, xd, 0.5);
double c1 = sign * (-2 * fputil::sqrt(u));
double c1 = sign * (-2 * fputil::sqrt<double>(u));
double c2 = fputil::multiply_add(sign, M_MATH_PI_2, c1);
double c3 = c1 * u;

Expand Down
6 changes: 3 additions & 3 deletions libc/src/math/generic/asinhf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ LLVM_LIBC_FUNCTION(float, asinhf, (float x)) {

// asinh(x) = log(x + sqrt(x^2 + 1))
return static_cast<float>(
x_sign *
log_eval(fputil::multiply_add(
x_d, x_sign, fputil::sqrt(fputil::multiply_add(x_d, x_d, 1.0)))));
x_sign * log_eval(fputil::multiply_add(
x_d, x_sign,
fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, 1.0)))));
}

} // namespace LIBC_NAMESPACE
19 changes: 19 additions & 0 deletions libc/src/math/generic/f16sqrtf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- Implementation of f16sqrtf function -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/f16sqrtf.h"
#include "src/__support/FPUtil/sqrt.h"
#include "src/__support/common.h"

namespace LIBC_NAMESPACE {

// f16sqrtf hands its float argument to the fputil::sqrt template, asking for
// a float16 result type, and returns that value unchanged.
LLVM_LIBC_FUNCTION(float16, f16sqrtf, (float x)) {
  const float16 narrowed_root = fputil::sqrt<float16>(x);
  return narrowed_root;
}

} // namespace LIBC_NAMESPACE
2 changes: 1 addition & 1 deletion libc/src/math/generic/hypotf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;

// Take sqrt in double precision.
DoubleBits result(fputil::sqrt(sum_sq));
DoubleBits result(fputil::sqrt<double>(sum_sq));

if (!DoubleBits(sum_sq).is_inf_or_nan()) {
// Correct rounding.
Expand Down
2 changes: 1 addition & 1 deletion libc/src/math/generic/powf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ LLVM_LIBC_FUNCTION(float, powf, (float x, float y)) {
switch (y_u) {
case 0x3f00'0000: // y = 0.5f
// pow(x, 1/2) = sqrt(x)
return fputil::sqrt(x);
return fputil::sqrt<float>(x);
case 0x3f80'0000: // y = 1.0f
return x;
case 0x4000'0000: // y = 2.0f
Expand Down
2 changes: 1 addition & 1 deletion libc/src/math/generic/sqrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt(x); }
LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt<double>(x); }

} // namespace LIBC_NAMESPACE
2 changes: 1 addition & 1 deletion libc/src/math/generic/sqrtf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt(x); }
LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt<float>(x); }

} // namespace LIBC_NAMESPACE
4 changes: 3 additions & 1 deletion libc/src/math/generic/sqrtf128.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) { return fputil::sqrt(x); }
LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
return fputil::sqrt<float128>(x);
}

} // namespace LIBC_NAMESPACE
2 changes: 1 addition & 1 deletion libc/src/math/generic/sqrtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) {
return fputil::sqrt(x);
return fputil::sqrt<long double>(x);
}

} // namespace LIBC_NAMESPACE
Loading
Loading