Skip to content

[libc][stdfix] Add sqrt for fixed point types. #83042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions libc/config/linux/x86_64/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,12 @@ if(LIBC_COMPILER_HAS_FIXED_POINT)
libc.src.stdfix.roundur
libc.src.stdfix.roundulk
libc.src.stdfix.roundulr
libc.src.stdfix.sqrtuhk
libc.src.stdfix.sqrtuhr
libc.src.stdfix.sqrtuk
libc.src.stdfix.sqrtur
# libc.src.stdfix.sqrtulk
libc.src.stdfix.sqrtulr
)
endif()

Expand Down
2 changes: 1 addition & 1 deletion libc/docs/math/stdfix.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Fixed-point Arithmetics
+---------------+----------------+-------------+---------------+------------+----------------+-------------+----------------+-------------+---------------+------------+----------------+-------------+
| round | |check| | |check| | |check| | |check| | |check| | |check| | |check| | |check| | |check| | |check| | |check| | |check| |
+---------------+----------------+-------------+---------------+------------+----------------+-------------+----------------+-------------+---------------+------------+----------------+-------------+
| sqrt | | | | | | | | | | | | |
| sqrt | |check| | | |check| | | |check| | | |check| | | |check| | | | |
+---------------+----------------+-------------+---------------+------------+----------------+-------------+----------------+-------------+---------------+------------+----------------+-------------+

================== =========
Expand Down
8 changes: 8 additions & 0 deletions libc/spec/stdc_ext.td
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def StdcExt : StandardSpec<"stdc_ext"> {
GuardedFunctionSpec<"rounduhk", RetValSpec<UnsignedShortAccumType>, [ArgSpec<UnsignedShortAccumType>, ArgSpec<IntType>], "LIBC_COMPILER_HAS_FIXED_POINT">,
GuardedFunctionSpec<"rounduk", RetValSpec<UnsignedAccumType>, [ArgSpec<UnsignedAccumType>, ArgSpec<IntType>], "LIBC_COMPILER_HAS_FIXED_POINT">,
GuardedFunctionSpec<"roundulk", RetValSpec<UnsignedLongAccumType>, [ArgSpec<UnsignedLongAccumType>, ArgSpec<IntType>], "LIBC_COMPILER_HAS_FIXED_POINT">,

GuardedFunctionSpec<"sqrtuhr", RetValSpec<UnsignedShortFractType>, [ArgSpec<UnsignedShortFractType>], "LIBC_COMPILER_HAS_FIXED_POINT">,
GuardedFunctionSpec<"sqrtur", RetValSpec<UnsignedFractType>, [ArgSpec<UnsignedFractType>], "LIBC_COMPILER_HAS_FIXED_POINT">,
GuardedFunctionSpec<"sqrtulr", RetValSpec<UnsignedLongFractType>, [ArgSpec<UnsignedLongFractType>], "LIBC_COMPILER_HAS_FIXED_POINT">,

GuardedFunctionSpec<"sqrtuhk", RetValSpec<UnsignedShortAccumType>, [ArgSpec<UnsignedShortAccumType>], "LIBC_COMPILER_HAS_FIXED_POINT">,
GuardedFunctionSpec<"sqrtuk", RetValSpec<UnsignedAccumType>, [ArgSpec<UnsignedAccumType>], "LIBC_COMPILER_HAS_FIXED_POINT">,
GuardedFunctionSpec<"sqrtulk", RetValSpec<UnsignedLongAccumType>, [ArgSpec<UnsignedLongAccumType>], "LIBC_COMPILER_HAS_FIXED_POINT">,
]
>;

Expand Down
13 changes: 13 additions & 0 deletions libc/src/__support/fixed_point/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,16 @@ add_header_library(
libc.src.__support.CPP.bit
libc.src.__support.math_extras
)

# Header-only library exposing sqrt.h, the generic fixed-point square-root
# template that the stdfix sqrt* entrypoints depend on.
add_header_library(
sqrt
HDRS
sqrt.h
DEPENDS
.fx_rep
libc.include.llvm-libc-macros.stdfix_macros
libc.src.__support.macros.attributes
libc.src.__support.macros.optimization
libc.src.__support.CPP.bit
libc.src.__support.CPP.type_traits
)
129 changes: 129 additions & 0 deletions libc/src/__support/fixed_point/sqrt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
//===-- Calculate square root of fixed point numbers. -----*- C++ -*-=========//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC___SUPPORT_FIXEDPOINT_SQRT_H
#define LLVM_LIBC_SRC___SUPPORT_FIXEDPOINT_SQRT_H

#include "include/llvm-libc-macros/stdfix-macros.h"
#include "src/__support/CPP/bit.h"
#include "src/__support/CPP/type_traits.h"
#include "src/__support/macros/attributes.h" // LIBC_INLINE
#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY

#include "fx_rep.h"

#ifdef LIBC_COMPILER_HAS_FIXED_POINT

namespace LIBC_NAMESPACE::fixed_point {

namespace internal {

template <typename T> struct SqrtConfig;

template <> struct SqrtConfig<unsigned short fract> {
using Type = unsigned short fract;
static constexpr int EXTRA_STEPS = 0;
};

template <> struct SqrtConfig<unsigned fract> {
using Type = unsigned fract;
static constexpr int EXTRA_STEPS = 1;
};

template <> struct SqrtConfig<unsigned long fract> {
using Type = unsigned long fract;
static constexpr int EXTRA_STEPS = 2;
};

template <>
struct SqrtConfig<unsigned short accum> : SqrtConfig<unsigned fract> {};

template <>
struct SqrtConfig<unsigned accum> : SqrtConfig<unsigned long fract> {};

// TODO: unsigned long accum type is 64-bit, and will need 64-bit fract type.
// Probably we will use DyadicFloat<64> for intermediate computations instead.

// Linear approximation for the initial values, with errors bounded by:
//   max(1.5 * 2^-11, eps)
// Row k holds the pair {slope, intercept} fitted on the interval
//   [(k + 4) * 2^-4, (k + 5) * 2^-4],
// so the initial estimate for a normalized input x in that interval is
//   SQRT_FIRST_APPROX[k][0] * x + SQRT_FIRST_APPROX[k][1].
// The script below iterates i from 4 to 15, which is why there are 12 rows.
// Generated with Sollya:
// > for i from 4 to 15 do {
//     P = fpminimax(sqrt(x), 1, [|8, 8|], [i * 2^-4, (i + 1)*2^-4],
//                   fixed, absolute);
//     print("{", coeff(P, 1), "uhr,", coeff(P, 0), "uhr},");
//   };
static constexpr unsigned short fract SQRT_FIRST_APPROX[12][2] = {
{0x1.e8p-1uhr, 0x1.0cp-2uhr}, {0x1.bap-1uhr, 0x1.28p-2uhr},
{0x1.94p-1uhr, 0x1.44p-2uhr}, {0x1.74p-1uhr, 0x1.6p-2uhr},
{0x1.6p-1uhr, 0x1.74p-2uhr}, {0x1.4ep-1uhr, 0x1.88p-2uhr},
{0x1.3ep-1uhr, 0x1.9cp-2uhr}, {0x1.32p-1uhr, 0x1.acp-2uhr},
{0x1.22p-1uhr, 0x1.c4p-2uhr}, {0x1.18p-1uhr, 0x1.d4p-2uhr},
{0x1.08p-1uhr, 0x1.fp-2uhr}, {0x1.04p-1uhr, 0x1.f8p-2uhr},
};

} // namespace internal

template <typename T>
LIBC_INLINE constexpr cpp::enable_if_t<cpp::is_fixed_point_v<T>, T> sqrt(T x) {
using BitType = typename FXRep<T>::StorageType;
BitType x_bit = cpp::bit_cast<BitType>(x);

if (LIBC_UNLIKELY(x_bit == 0))
return FXRep<T>::ZERO();

int leading_zeros = cpp::countl_zero(x_bit);
constexpr int STORAGE_LENGTH = sizeof(BitType) * CHAR_BIT;
constexpr int EXP_ADJUSTMENT = STORAGE_LENGTH - FXRep<T>::FRACTION_LEN - 1;
// x_exp is the real exponent of the leading bit of x.
int x_exp = EXP_ADJUSTMENT - leading_zeros;
int shift = EXP_ADJUSTMENT - 1 - (x_exp & (~1));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that ~1 may have problems on 64 bit types since it's an int by default. Is it worthwhile to make this even in a different way?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is all computed in int, and all the variables and intermediates are in int so I think it should be fine. Another way to do it is (x_exp >> 1) << 1 which I think is clunkier.

// Normalize.
x_bit <<= shift;
using FracType = typename internal::SqrtConfig<T>::Type;
FracType x_frac = cpp::bit_cast<FracType>(x_bit);

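// Use Newton's method to approximate sqrt(x_frac):
//   r_{n + 1} = 1/2 (r_n + x_frac / r_n)
// For the initial value r_0, we use a linear approximation of sqrt on a
// small interval containing x_frac, selected by the table lookup below.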

// Use the leading 4 bits to do look up for sqrt(x).
// After normalization, 0.25 <= x_frac < 1, so the leading 4 bits of x_frac
// are between 0b0100 and 0b1111. Hence the lookup table only needs 12
// entries, and we can get the index by subtracting the leading 4 bits of
// x_frac by 4 = 0b0100.
int index = (x_bit >> (STORAGE_LENGTH - 4)) - 4;
FracType a = static_cast<FracType>(internal::SQRT_FIRST_APPROX[index][0]);
FracType b = static_cast<FracType>(internal::SQRT_FIRST_APPROX[index][1]);

// Initial approximation step.
// Estimated error bounds: | r - sqrt(x_frac) | < max(1.5 * 2^-11, eps).
FracType r = a * x_frac + b;

// Further Newton-method iterations for square-root:
// x_{n + 1} = 0.5 * (x_n + a / x_n)
// We distribute and do the multiplication by 0.5 first to avoid overflow.
// TODO: Investigate the performance and accuracy of using division-free
// iterations from:
// Blanchard, J. D. and Chamberland, M., "Newton's Method Without Division",
// The American Mathematical Monthly (2023).
// https://chamberland.math.grinnell.edu/papers/newton.pdf
for (int i = 0; i < internal::SqrtConfig<T>::EXTRA_STEPS; ++i)
r = (r >> 1) + (x_frac >> 1) / r;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you remove the division, I think this could all be calculated as integer operations, removing the need for a 64 bit fract type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit tricky if we want to just use integers. Shifts and additions are totally fine, but replacement for division will involve multiplications, and for that we need to keep track of the correct bit positions, which the fixed point types are doing for us internally. Otherwise, the integer multiplications will only retain the least significant bits.

Using DyadicFloat will fit the requirement for fixed point and it is quite convenient for us, but might be a bit inefficient, since DyadicFloat will normalize the bits after every operation. I guess eventually we will implement a fixed point class with all the basic arithmetic operations, similar to FPBits and DyadicFloat, then we can simply use it instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the situation, this seems reasonable. I think Making a fixed point type comparable to DyadicFloat is the best path forward.


// Re-scaling
r >>= EXP_ADJUSTMENT - (x_exp >> 1);

// Return result.
return cpp::bit_cast<T>(r);
}

} // namespace LIBC_NAMESPACE::fixed_point

#endif // LIBC_COMPILER_HAS_FIXED_POINT

#endif // LLVM_LIBC_SRC___SUPPORT_FIXEDPOINT_SQRT_H
15 changes: 15 additions & 0 deletions libc/src/stdfix/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ foreach(suffix IN ITEMS hr r lr hk k lk)
)
endforeach()

# One sqrt entrypoint per unsigned fixed-point suffix. `ulk` (unsigned long
# accum) is not in this list; sqrt.h carries a TODO noting that the 64-bit
# unsigned long accum type will need a 64-bit fract type for intermediates.
foreach(suffix IN ITEMS uhr ur ulr uhk uk)
add_entrypoint_object(
sqrt${suffix}
HDRS
sqrt${suffix}.h
SRCS
sqrt${suffix}.cpp
COMPILE_OPTIONS
-O3
-ffixed-point
DEPENDS
libc.src.__support.fixed_point.sqrt
)
endforeach()

foreach(suffix IN ITEMS hr r lr hk k lk uhr ur ulr uhk uk ulk)
add_entrypoint_object(
round${suffix}
Expand Down
19 changes: 19 additions & 0 deletions libc/src/stdfix/sqrtuhk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- Implementation of sqrtuhk function --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "sqrtuhk.h"
#include "src/__support/common.h"
#include "src/__support/fixed_point/sqrt.h"

namespace LIBC_NAMESPACE {

// Defines sqrtuhk: takes an `unsigned short accum` and returns the result of
// the generic fixed_point::sqrt template (sqrt.h) applied to it.
LLVM_LIBC_FUNCTION(unsigned short accum, sqrtuhk, (unsigned short accum x)) {
return fixed_point::sqrt(x);
}

} // namespace LIBC_NAMESPACE
20 changes: 20 additions & 0 deletions libc/src/stdfix/sqrtuhk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===-- Implementation header for sqrtuhk -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_STDFIX_SQRTUHK_H
#define LLVM_LIBC_SRC_STDFIX_SQRTUHK_H

#include "include/llvm-libc-macros/stdfix-macros.h"

namespace LIBC_NAMESPACE {

// Square root of an `unsigned short accum` value; defined in sqrtuhk.cpp.
unsigned short accum sqrtuhk(unsigned short accum x);

} // namespace LIBC_NAMESPACE

#endif // LLVM_LIBC_SRC_STDFIX_SQRTUHK_H
19 changes: 19 additions & 0 deletions libc/src/stdfix/sqrtuhr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- Implementation of sqrtuhr function --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "sqrtuhr.h"
#include "src/__support/common.h"
#include "src/__support/fixed_point/sqrt.h"

namespace LIBC_NAMESPACE {

// Defines sqrtuhr: takes an `unsigned short fract` and returns the result of
// the generic fixed_point::sqrt template (sqrt.h) applied to it.
LLVM_LIBC_FUNCTION(unsigned short fract, sqrtuhr, (unsigned short fract x)) {
return fixed_point::sqrt(x);
}

} // namespace LIBC_NAMESPACE
20 changes: 20 additions & 0 deletions libc/src/stdfix/sqrtuhr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===-- Implementation header for sqrtuhr -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_STDFIX_SQRTUHR_H
#define LLVM_LIBC_SRC_STDFIX_SQRTUHR_H

#include "include/llvm-libc-macros/stdfix-macros.h"

namespace LIBC_NAMESPACE {

// Square root of an `unsigned short fract` value; defined in sqrtuhr.cpp.
unsigned short fract sqrtuhr(unsigned short fract x);

} // namespace LIBC_NAMESPACE

#endif // LLVM_LIBC_SRC_STDFIX_SQRTUHR_H
19 changes: 19 additions & 0 deletions libc/src/stdfix/sqrtuk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- Implementation of sqrtuk function ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "sqrtuk.h"
#include "src/__support/common.h"
#include "src/__support/fixed_point/sqrt.h"

namespace LIBC_NAMESPACE {

// Defines sqrtuk: takes an `unsigned accum` and returns the result of the
// generic fixed_point::sqrt template (sqrt.h) applied to it.
LLVM_LIBC_FUNCTION(unsigned accum, sqrtuk, (unsigned accum x)) {
return fixed_point::sqrt(x);
}

} // namespace LIBC_NAMESPACE
20 changes: 20 additions & 0 deletions libc/src/stdfix/sqrtuk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===-- Implementation header for sqrtuk ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_STDFIX_SQRTUK_H
#define LLVM_LIBC_SRC_STDFIX_SQRTUK_H

#include "include/llvm-libc-macros/stdfix-macros.h"

namespace LIBC_NAMESPACE {

// Square root of an `unsigned accum` value; defined in sqrtuk.cpp.
unsigned accum sqrtuk(unsigned accum x);

} // namespace LIBC_NAMESPACE

#endif // LLVM_LIBC_SRC_STDFIX_SQRTUK_H
19 changes: 19 additions & 0 deletions libc/src/stdfix/sqrtulr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- Implementation of sqrtulr function -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "sqrtulr.h"
#include "src/__support/common.h"
#include "src/__support/fixed_point/sqrt.h"

namespace LIBC_NAMESPACE {

// Defines sqrtulr: takes an `unsigned long fract` and returns the result of
// the generic fixed_point::sqrt template (sqrt.h) applied to it.
LLVM_LIBC_FUNCTION(unsigned long fract, sqrtulr, (unsigned long fract x)) {
return fixed_point::sqrt(x);
}

} // namespace LIBC_NAMESPACE
20 changes: 20 additions & 0 deletions libc/src/stdfix/sqrtulr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===-- Implementation header for sqrtulr -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_STDFIX_SQRTULR_H
#define LLVM_LIBC_SRC_STDFIX_SQRTULR_H

#include "include/llvm-libc-macros/stdfix-macros.h"

namespace LIBC_NAMESPACE {

// Square root of an `unsigned long fract` value; defined in sqrtulr.cpp.
unsigned long fract sqrtulr(unsigned long fract x);

} // namespace LIBC_NAMESPACE

#endif // LLVM_LIBC_SRC_STDFIX_SQRTULR_H
19 changes: 19 additions & 0 deletions libc/src/stdfix/sqrtur.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- Implementation of sqrtur function ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "sqrtur.h"
#include "src/__support/common.h"
#include "src/__support/fixed_point/sqrt.h"

namespace LIBC_NAMESPACE {

// Defines sqrtur: takes an `unsigned fract` and returns the result of the
// generic fixed_point::sqrt template (sqrt.h) applied to it.
LLVM_LIBC_FUNCTION(unsigned fract, sqrtur, (unsigned fract x)) {
return fixed_point::sqrt(x);
}

} // namespace LIBC_NAMESPACE
Loading