Skip to content

[flang][AIX] Handle more trig functions with complex argument to have consistent results in folding #124203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 30, 2025
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 181 additions & 56 deletions flang/lib/Evaluate/intrinsics-library.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,23 +260,23 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
static_assert(map.Verify(), "map must be sorted");
};

// Helpers to map complex std::pow whose resolution in F2{std::pow} is
// ambiguous as of clang++ 20.
template <typename HostT>
static std::complex<HostT> StdPowF2(
const std::complex<HostT> &x, const std::complex<HostT> &y) {
return std::pow(x, y);
}
template <typename HostT>
static std::complex<HostT> StdPowF2A(
const HostT &x, const std::complex<HostT> &y) {
return std::pow(x, y);
}
template <typename HostT>
static std::complex<HostT> StdPowF2B(
const std::complex<HostT> &x, const HostT &y) {
return std::pow(x, y);
}
enum trigFunc {
Cacos,
Cacosh,
Casin,
Casinh,
Catan,
Catanh,
Ccos,
Ccosh,
Cexp,
Clog,
Csin,
Csinh,
Csqrt,
Ctan,
Ctanh
};

#ifdef _AIX
#ifdef __clang_major__
Expand All @@ -286,8 +286,36 @@ static std::complex<HostT> StdPowF2B(
extern "C" {
float _Complex cacosf(float _Complex);
double _Complex cacos(double _Complex);
float _Complex cacoshf(float _Complex);
double _Complex cacosh(double _Complex);
float _Complex casinf(float _Complex);
double _Complex casin(double _Complex);
float _Complex casinhf(float _Complex);
double _Complex casinh(double _Complex);
float _Complex catanf(float _Complex);
double _Complex catan(double _Complex);
float _Complex catanhf(float _Complex);
double _Complex catanh(double _Complex);
float _Complex ccosf(float _Complex);
double _Complex ccos(double _Complex);
float _Complex ccoshf(float _Complex);
double _Complex ccosh(double _Complex);
float _Complex cexpf(float _Complex);
double _Complex cexp(double _Complex);
float _Complex clogf(float _Complex);
double _Complex __clog(double _Complex);
float _Complex cpowf(float _Complex, float _Complex);
double _Complex cpow(double _Complex, double _Complex);
float _Complex csinf(float _Complex);
double _Complex csin(double _Complex);
float _Complex csinhf(float _Complex);
double _Complex csinh(double _Complex);
float _Complex csqrtf(float _Complex);
double _Complex csqrt(double _Complex);
float _Complex ctanf(float _Complex);
double _Complex ctan(double _Complex);
float _Complex ctanhf(float _Complex);
double _Complex ctanh(double _Complex);
}

enum CRI { Real, Imag };
Expand All @@ -304,48 +332,145 @@ template <typename T, typename TA> static std::complex<T> CToCpp(const TA &x) {
TA &z{const_cast<TA &>(x)};
return std::complex<T>(reIm<T, TA>(z, CRI::Real), reIm<T, TA>(z, CRI::Imag));
}

using FTypeCmplxFlt = _Complex float (*)(_Complex float);
using FTypeCmplxDble = _Complex double (*)(_Complex double);
template <typename T>
using FTypeStdCmplx = std::complex<T> (*)(const std::complex<T> &);

std::map<trigFunc, std::tuple<FTypeCmplxFlt, FTypeCmplxDble>> mapLibmTrigFunc{
{Cacos, {&cacosf, &cacos}}, {Cacosh, {&cacoshf, &cacosh}},
{Casin, {&casinf, &casin}}, {Casinh, {&casinhf, &casinh}},
{Catan, {&catanf, &catan}}, {Catanh, {&catanhf, &catanh}},
{Ccos, {&ccosf, &ccos}}, {Ccosh, {&ccoshf, &ccosh}},
{Cexp, {&cexpf, &cexp}}, {Clog, {&clogf, &__clog}}, {Csin, {&csinf, &csin}},
{Csinh, {&csinhf, &csinh}}, {Csqrt, {&csqrtf, &csqrt}},
{Ctan, {&ctanf, &ctan}}, {Ctanh, {&ctanhf, &ctanh}}};

template <trigFunc TF, typename HostT>
std::complex<HostT> LibmTrigFunc(const std::complex<HostT> &x) {
if constexpr (std::is_same_v<HostT, float>) {
float _Complex r{std::get<FTypeCmplxFlt>(mapLibmTrigFunc[TF])(
CppToC<float _Complex, float>(x))};
return CToCpp<float, float _Complex>(r);
} else if constexpr (std::is_same_v<HostT, double>) {
double _Complex r{std::get<FTypeCmplxDble>(mapLibmTrigFunc[TF])(
CppToC<double _Complex, double>(x))};
return CToCpp<double, double _Complex>(r);
}
DIE("bad complex component type");
}
#endif

template <trigFunc TF, typename HostT>
std::complex<HostT> StdTrigFunc(const std::complex<HostT> &x) {
if constexpr (TF == Cacos) {
return std::acos(x);
} else if constexpr (TF == Cacosh) {
return std::acosh(x);
} else if constexpr (TF == Casin) {
return std::asin(x);
} else if constexpr (TF == Casinh) {
return std::asinh(x);
} else if constexpr (TF == Catan) {
return std::atan(x);
} else if constexpr (TF == Catanh) {
return std::atanh(x);
} else if constexpr (TF == Ccos) {
return std::cos(x);
} else if constexpr (TF == Ccosh) {
return std::cosh(x);
} else if constexpr (TF == Cexp) {
return std::exp(x);
} else if constexpr (TF == Clog) {
return std::log(x);
} else if constexpr (TF == Csin) {
return std::sin(x);
} else if constexpr (TF == Csinh) {
return std::sinh(x);
} else if constexpr (TF == Csqrt) {
return std::sqrt(x);
} else if constexpr (TF == Ctan) {
return std::tan(x);
} else if constexpr (TF == Ctanh) {
return std::tanh(x);
}
DIE("unknown function");
}

template <trigFunc TF> struct X {
template <typename HostT>
static std::complex<HostT> f(const std::complex<HostT> &x) {
std::complex<HostT> res;
#ifdef _AIX
// On AIX, the implementation in libm is different from that of the STL
// routines, use the libm routines here in folding for consistent results.
res = LibmTrigFunc<TF>(x);
#else
res = StdTrigFunc<TF, HostT>(x);
#endif
return res;
}
};

// Helpers to map complex std::pow whose resolution in F2{std::pow} is
// ambiguous as of clang++ 20.
template <typename HostT>
static std::complex<HostT> StdPowF2(
const std::complex<HostT> &x, const std::complex<HostT> &y) {
#ifdef _AIX
if constexpr (std::is_same_v<HostT, float>) {
float _Complex r{cpowf(
CppToC<float _Complex, float>(x), CppToC<float _Complex, float>(y))};
return CToCpp<float, float _Complex>(r);
} else if constexpr (std::is_same_v<HostT, double>) {
double _Complex r{cpow(CppToC<double _Complex, double>(x),
CppToC<double _Complex, double>(y))};
return CToCpp<double, double _Complex>(r);
}
#else
return std::pow(x, y);
#endif
}

template <typename HostT>
static std::complex<HostT> CSqrt(const std::complex<HostT> &x) {
std::complex<HostT> res;
static std::complex<HostT> StdPowF2A(
const HostT &x, const std::complex<HostT> &y) {
#ifdef _AIX
// On AIX, the implementation of csqrt[f] and std::sqrt is different,
// use csqrt[f] in folding.
constexpr HostT zero{0.0};
std::complex<HostT> z(x, zero);
if constexpr (std::is_same_v<HostT, float>) {
float _Complex r{csqrtf(CppToC<float _Complex, float>(x))};
res = CToCpp<float, float _Complex>(r);
float _Complex r{cpowf(
CppToC<float _Complex, float>(z), CppToC<float _Complex, float>(y))};
return CToCpp<float, float _Complex>(r);
} else if constexpr (std::is_same_v<HostT, double>) {
double _Complex r{csqrt(CppToC<double _Complex, double>(x))};
res = CToCpp<double, double _Complex>(r);
} else {
DIE("bad complex component type");
double _Complex r{cpow(CppToC<double _Complex, double>(z),
CppToC<double _Complex, double>(y))};
return CToCpp<double, double _Complex>(r);
}
#else
res = std::sqrt(x);
return std::pow(x, y);
#endif
return res;
}

template <typename HostT>
static std::complex<HostT> CAcos(const std::complex<HostT> &x) {
std::complex<HostT> res;
static std::complex<HostT> StdPowF2B(
const std::complex<HostT> &x, const HostT &y) {
#ifdef _AIX
// On AIX, the implementation of cacos[f] and std::acos is different,
// use cacos[f] in folding.
constexpr HostT zero{0.0};
std::complex<HostT> z(y, zero);
if constexpr (std::is_same_v<HostT, float>) {
float _Complex r{cacosf(CppToC<float _Complex, float>(x))};
res = CToCpp<float, float _Complex>(r);
float _Complex r{cpowf(
CppToC<float _Complex, float>(x), CppToC<float _Complex, float>(z))};
return CToCpp<float, float _Complex>(r);
} else if constexpr (std::is_same_v<HostT, double>) {
double _Complex r{cacos(CppToC<double _Complex, double>(x))};
res = CToCpp<double, double _Complex>(r);
} else {
DIE("bad complex component type");
double _Complex r{cpow(CppToC<double _Complex, double>(x),
CppToC<double _Complex, double>(z))};
return CToCpp<double, double _Complex>(r);
}
#else
res = std::acos(x);
return std::pow(x, y);
#endif
return res;
}

template <typename HostT>
Expand All @@ -358,24 +483,24 @@ struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
using F2B = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
const HostT &>;
static constexpr HostRuntimeFunction table[]{
FolderFactory<F, F{CAcos}>::Create("acos"),
FolderFactory<F, F{std::acosh}>::Create("acosh"),
FolderFactory<F, F{std::asin}>::Create("asin"),
FolderFactory<F, F{std::asinh}>::Create("asinh"),
FolderFactory<F, F{std::atan}>::Create("atan"),
FolderFactory<F, F{std::atanh}>::Create("atanh"),
FolderFactory<F, F{std::cos}>::Create("cos"),
FolderFactory<F, F{std::cosh}>::Create("cosh"),
FolderFactory<F, F{std::exp}>::Create("exp"),
FolderFactory<F, F{std::log}>::Create("log"),
FolderFactory<F, F{X<Cacos>::f}>::Create("acos"),
FolderFactory<F, F{X<Cacosh>::f}>::Create("acosh"),
FolderFactory<F, F{X<Casin>::f}>::Create("asin"),
FolderFactory<F, F{X<Casinh>::f}>::Create("asinh"),
FolderFactory<F, F{X<Catan>::f}>::Create("atan"),
FolderFactory<F, F{X<Catanh>::f}>::Create("atanh"),
FolderFactory<F, F{X<Ccos>::f}>::Create("cos"),
FolderFactory<F, F{X<Ccosh>::f}>::Create("cosh"),
FolderFactory<F, F{X<Cexp>::f}>::Create("exp"),
FolderFactory<F, F{X<Clog>::f}>::Create("log"),
FolderFactory<F2, F2{StdPowF2}>::Create("pow"),
FolderFactory<F2A, F2A{StdPowF2A}>::Create("pow"),
FolderFactory<F2B, F2B{StdPowF2B}>::Create("pow"),
FolderFactory<F, F{std::sin}>::Create("sin"),
FolderFactory<F, F{std::sinh}>::Create("sinh"),
FolderFactory<F, F{CSqrt}>::Create("sqrt"),
FolderFactory<F, F{std::tan}>::Create("tan"),
FolderFactory<F, F{std::tanh}>::Create("tanh"),
FolderFactory<F, F{X<Csin>::f}>::Create("sin"),
FolderFactory<F, F{X<Csinh>::f}>::Create("sinh"),
FolderFactory<F, F{X<Csqrt>::f}>::Create("sqrt"),
FolderFactory<F, F{X<Ctan>::f}>::Create("tan"),
FolderFactory<F, F{X<Ctanh>::f}>::Create("tanh"),
};
static constexpr HostRuntimeMap map{table};
static_assert(map.Verify(), "map must be sorted");
Expand Down
Loading