Skip to content

[flang][AIX] Handle more trig functions with complex argument to have consistent results in folding #124203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 131 additions & 79 deletions flang/lib/Evaluate/intrinsics-library.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,105 +260,41 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
static_assert(map.Verify(), "map must be sorted");
};

// Declares, in the scope where it is expanded, the function-pointer aliases
// used by the complex host runtime tables for component type HOST_T:
//   F   : complex(const complex &)
//   F2  : complex(const complex &, const complex &)
//   F2A : complex(const HOST_T &, const complex &)
//   F2B : complex(const complex &, const HOST_T &)
#define COMPLEX_SIGNATURES(HOST_T) \
using F = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &>; \
using F2 = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
const std::complex<HOST_T> &>; \
using F2A = FuncPointer<std::complex<HOST_T>, const HOST_T &, \
const std::complex<HOST_T> &>; \
using F2B = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
const HOST_T &>;

#ifndef _AIX
// Helpers to map complex std::pow whose resolution in F2{std::pow} is
// ambiguous as of clang++ 20.
// complex ** complex: a single, non-overloaded entry point that forwards to
// std::pow so that its address can be taken without overload ambiguity.
template <typename HostT>
static std::complex<HostT> StdPowF2(
    const std::complex<HostT> &x, const std::complex<HostT> &y) {
  std::complex<HostT> power{std::pow(x, y)};
  return power;
}

// real ** complex: a single, non-overloaded entry point that forwards to
// std::pow for a real base and a complex exponent.
template <typename HostT>
static std::complex<HostT> StdPowF2A(
    const HostT &x, const std::complex<HostT> &y) {
  std::complex<HostT> power{std::pow(x, y)};
  return power;
}

// complex ** real: a single, non-overloaded entry point that forwards to
// std::pow for a complex base and a real exponent.
template <typename HostT>
static std::complex<HostT> StdPowF2B(
    const std::complex<HostT> &x, const HostT &y) {
  std::complex<HostT> power{std::pow(x, y)};
  return power;
}

#ifdef _AIX
#ifdef __clang_major__
#pragma clang diagnostic ignored "-Wc99-extensions"
#endif

// Prototypes, with C linkage, for the C complex arc-cosine and square-root
// routines in single (…f) and double precision. They take and return C
// _Complex values rather than std::complex.
extern "C" {
float _Complex cacosf(float _Complex);
double _Complex cacos(double _Complex);
float _Complex csqrtf(float _Complex);
double _Complex csqrt(double _Complex);
}

// Index of the real (0) and imaginary (1) component when a complex value is
// viewed as a two-element array.
enum CRI { Real, Imag };

// Returns a reference to component n of x, viewing the storage of x as an
// array TR[2]. TR may be const-qualified to obtain read-only access to a
// const-qualified TA.
template <typename TR, typename TA> static TR &reIm(TA &x, CRI n) {
  return reinterpret_cast<TR(&)[2]>(x)[n];
}

// Converts a std::complex<T> into the two-component type TR (e.g. a C
// _Complex type) by copying the real and imaginary parts component-wise.
template <typename TR, typename T> static TR CppToC(const std::complex<T> &x) {
  TR r;
  reIm<T, TR>(r, CRI::Real) = x.real();
  reIm<T, TR>(r, CRI::Imag) = x.imag();
  return r;
}

// Converts a two-component value of type TA (e.g. a C _Complex type) into a
// std::complex<T>. The source is only read, so it is viewed through
// const-qualified types instead of casting its constness away.
template <typename T, typename TA> static std::complex<T> CToCpp(const TA &x) {
  return std::complex<T>(reIm<const T, const TA>(x, CRI::Real),
      reIm<const T, const TA>(x, CRI::Imag));
}
#endif

// Complex square root used when folding "sqrt".
template <typename HostT>
static std::complex<HostT> CSqrt(const std::complex<HostT> &x) {
#ifdef _AIX
  // AIX implements csqrt[f] and std::sqrt differently; folding goes through
  // csqrt[f].
  std::complex<HostT> root;
  if constexpr (std::is_same_v<HostT, float>) {
    root = CToCpp<float, float _Complex>(
        csqrtf(CppToC<float _Complex, float>(x)));
  } else if constexpr (std::is_same_v<HostT, double>) {
    root = CToCpp<double, double _Complex>(
        csqrt(CppToC<double _Complex, double>(x)));
  } else {
    DIE("bad complex component type");
  }
  return root;
#else
  return std::sqrt(x);
#endif
}

// Complex arc cosine used when folding "acos".
template <typename HostT>
static std::complex<HostT> CAcos(const std::complex<HostT> &x) {
#ifdef _AIX
  // AIX implements cacos[f] and std::acos differently; folding goes through
  // cacos[f].
  std::complex<HostT> angle;
  if constexpr (std::is_same_v<HostT, float>) {
    angle = CToCpp<float, float _Complex>(
        cacosf(CppToC<float _Complex, float>(x)));
  } else if constexpr (std::is_same_v<HostT, double>) {
    angle = CToCpp<double, double _Complex>(
        cacos(CppToC<double _Complex, double>(x)));
  } else {
    DIE("bad complex component type");
  }
  return angle;
#else
  return std::acos(x);
#endif
}

template <typename HostT>
struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
using F = FuncPointer<std::complex<HostT>, const std::complex<HostT> &>;
using F2 = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
const std::complex<HostT> &>;
using F2A = FuncPointer<std::complex<HostT>, const HostT &,
const std::complex<HostT> &>;
using F2B = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
const HostT &>;
COMPLEX_SIGNATURES(HostT)
static constexpr HostRuntimeFunction table[]{
FolderFactory<F, F{CAcos}>::Create("acos"),
FolderFactory<F, F{std::acos}>::Create("acos"),
FolderFactory<F, F{std::acosh}>::Create("acosh"),
FolderFactory<F, F{std::asin}>::Create("asin"),
FolderFactory<F, F{std::asinh}>::Create("asinh"),
Expand All @@ -373,13 +309,129 @@ struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
FolderFactory<F2B, F2B{StdPowF2B}>::Create("pow"),
FolderFactory<F, F{std::sin}>::Create("sin"),
FolderFactory<F, F{std::sinh}>::Create("sinh"),
FolderFactory<F, F{CSqrt}>::Create("sqrt"),
FolderFactory<F, F{std::sqrt}>::Create("sqrt"),
FolderFactory<F, F{std::tan}>::Create("tan"),
FolderFactory<F, F{std::tanh}>::Create("tanh"),
};
static constexpr HostRuntimeMap map{table};
static_assert(map.Verify(), "map must be sorted");
};
#else
// On AIX, call the libm routines directly so that values folded at compile
// time are consistent with those computed at runtime.
#ifdef __clang_major__
#pragma clang diagnostic ignored "-Wc99-extensions"
#endif

// Prototypes, with C linkage, for the C complex math routines referenced by
// the AIX host runtime tables. Each routine is declared in a single-precision
// (float _Complex) and a double-precision (double _Complex) form.
extern "C" {
float _Complex cacosf(float _Complex);
double _Complex cacos(double _Complex);
float _Complex cacoshf(float _Complex);
double _Complex cacosh(double _Complex);
float _Complex casinf(float _Complex);
double _Complex casin(double _Complex);
float _Complex casinhf(float _Complex);
double _Complex casinh(double _Complex);
float _Complex catanf(float _Complex);
double _Complex catan(double _Complex);
float _Complex catanhf(float _Complex);
double _Complex catanh(double _Complex);
float _Complex ccosf(float _Complex);
double _Complex ccos(double _Complex);
float _Complex ccoshf(float _Complex);
double _Complex ccosh(double _Complex);
float _Complex cexpf(float _Complex);
double _Complex cexp(double _Complex);
float _Complex clogf(float _Complex);
// NOTE(review): the double-precision logarithm is declared as __clog, not
// clog, unlike every other routine here — presumably to avoid a name clash on
// AIX; confirm against the platform headers.
double _Complex __clog(double _Complex);
float _Complex cpowf(float _Complex, float _Complex);
double _Complex cpow(double _Complex, double _Complex);
float _Complex csinf(float _Complex);
double _Complex csin(double _Complex);
float _Complex csinhf(float _Complex);
double _Complex csinh(double _Complex);
float _Complex csqrtf(float _Complex);
double _Complex csqrt(double _Complex);
float _Complex ctanf(float _Complex);
double _Complex ctan(double _Complex);
float _Complex ctanhf(float _Complex);
double _Complex ctanh(double _Complex);
}

// Type mapping from a C-level type to its C++ counterpart:
//   Type  - the corresponding value type
//   AType - the corresponding parameter type
// A type without a specialization maps to itself for both.
template <typename T> struct ToStdComplex {
  using Type = T;
  using AType = T;
};
// float _Complex corresponds to std::complex<float>, taken by const reference.
template <> struct ToStdComplex<float _Complex> {
  using Type = std::complex<float>;
  using AType = const std::complex<float> &;
};
// double _Complex corresponds to std::complex<double>, taken by const
// reference.
template <> struct ToStdComplex<double _Complex> {
  using Type = std::complex<double>;
  using AType = const std::complex<double> &;
};

// Adapts a C function `func` into a static function `wrapper` whose result
// and parameter types are those given by ToStdComplex. The primary template
// is empty; only function pointers match the partial specialization.
template <typename F, F func> struct CComplexFunc {};
template <typename R, typename... A, FuncPointer<R, A...> func>
struct CComplexFunc<FuncPointer<R, A...>, func> {
  static auto wrapper(typename ToStdComplex<A>::AType... args) ->
      typename ToStdComplex<R>::Type {
    using StdResult = typename ToStdComplex<R>::Type;
    // Each argument's storage is reinterpreted as the C parameter type, and
    // the C result's storage is reinterpreted as the C++ result type.
    R cResult{func(*reinterpret_cast<const A *>(&args)...)};
    return *reinterpret_cast<StdResult *>(&cResult);
  }
};
// Names the wrapper generated for the C function `func`.
#define C_COMPLEX_FUNC(func) CComplexFunc<decltype(&func), &func>::wrapper

// Libm host runtime table for std::complex<float>. Each entry pairs an
// intrinsic name with the single-precision C complex routine wrapped by
// C_COMPLEX_FUNC.
template <>
struct HostRuntimeLibrary<std::complex<float>, LibraryVersion::Libm> {
// Brings the F/F2/F2A/F2B function-pointer aliases for float into scope.
COMPLEX_SIGNATURES(float)
// Entries are listed in ascending name order; the static_assert below checks
// map.Verify(). Only the complex/complex (F2) form of "pow" is registered.
static constexpr HostRuntimeFunction table[]{
FolderFactory<F, C_COMPLEX_FUNC(cacosf)>::Create("acos"),
FolderFactory<F, C_COMPLEX_FUNC(cacoshf)>::Create("acosh"),
FolderFactory<F, C_COMPLEX_FUNC(casinf)>::Create("asin"),
FolderFactory<F, C_COMPLEX_FUNC(casinhf)>::Create("asinh"),
FolderFactory<F, C_COMPLEX_FUNC(catanf)>::Create("atan"),
FolderFactory<F, C_COMPLEX_FUNC(catanhf)>::Create("atanh"),
FolderFactory<F, C_COMPLEX_FUNC(ccosf)>::Create("cos"),
FolderFactory<F, C_COMPLEX_FUNC(ccoshf)>::Create("cosh"),
FolderFactory<F, C_COMPLEX_FUNC(cexpf)>::Create("exp"),
FolderFactory<F, C_COMPLEX_FUNC(clogf)>::Create("log"),
FolderFactory<F2, C_COMPLEX_FUNC(cpowf)>::Create("pow"),
FolderFactory<F, C_COMPLEX_FUNC(csinf)>::Create("sin"),
FolderFactory<F, C_COMPLEX_FUNC(csinhf)>::Create("sinh"),
FolderFactory<F, C_COMPLEX_FUNC(csqrtf)>::Create("sqrt"),
FolderFactory<F, C_COMPLEX_FUNC(ctanf)>::Create("tan"),
FolderFactory<F, C_COMPLEX_FUNC(ctanhf)>::Create("tanh"),
};
static constexpr HostRuntimeMap map{table};
static_assert(map.Verify(), "map must be sorted");
};
// Libm host runtime table for std::complex<double>. Each entry pairs an
// intrinsic name with the double-precision C complex routine wrapped by
// C_COMPLEX_FUNC.
template <>
struct HostRuntimeLibrary<std::complex<double>, LibraryVersion::Libm> {
// Brings the F/F2/F2A/F2B function-pointer aliases for double into scope.
COMPLEX_SIGNATURES(double)
// Entries are listed in ascending name order; the static_assert below checks
// map.Verify(). Only the complex/complex (F2) form of "pow" is registered.
static constexpr HostRuntimeFunction table[]{
FolderFactory<F, C_COMPLEX_FUNC(cacos)>::Create("acos"),
FolderFactory<F, C_COMPLEX_FUNC(cacosh)>::Create("acosh"),
FolderFactory<F, C_COMPLEX_FUNC(casin)>::Create("asin"),
FolderFactory<F, C_COMPLEX_FUNC(casinh)>::Create("asinh"),
FolderFactory<F, C_COMPLEX_FUNC(catan)>::Create("atan"),
FolderFactory<F, C_COMPLEX_FUNC(catanh)>::Create("atanh"),
FolderFactory<F, C_COMPLEX_FUNC(ccos)>::Create("cos"),
FolderFactory<F, C_COMPLEX_FUNC(ccosh)>::Create("cosh"),
FolderFactory<F, C_COMPLEX_FUNC(cexp)>::Create("exp"),
// "log" is bound to __clog, the name under which the double-precision
// routine is declared in this file.
FolderFactory<F, C_COMPLEX_FUNC(__clog)>::Create("log"),
FolderFactory<F2, C_COMPLEX_FUNC(cpow)>::Create("pow"),
FolderFactory<F, C_COMPLEX_FUNC(csin)>::Create("sin"),
FolderFactory<F, C_COMPLEX_FUNC(csinh)>::Create("sinh"),
FolderFactory<F, C_COMPLEX_FUNC(csqrt)>::Create("sqrt"),
FolderFactory<F, C_COMPLEX_FUNC(ctan)>::Create("tan"),
FolderFactory<F, C_COMPLEX_FUNC(ctanh)>::Create("tanh"),
};
static constexpr HostRuntimeMap map{table};
static_assert(map.Verify(), "map must be sorted");
};
#endif // _AIX

// Note regarding cmath:
// - cmath does not have modulo and erfc_scaled equivalent
// - C++17 defined standard Bessel math functions std::cyl_bessel_j
Expand Down
Loading