[flang][AIX] Handle more trig functions with complex argument to have consistent results in folding #124203

Merged
kkwli merged 6 commits into llvm:main from the c-trig branch on Jan 30, 2025

Conversation

Collaborator

@kkwli kkwli commented Jan 23, 2025

This patch extends 71d4f34 to cover all trig functions that accept arguments of complex type. On AIX, compile-time folding now calls the libm routines instead of the C++ standard library (std::) routines, so folded results are consistent with the libm results.
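
A minimal sketch of the idea, not the code from this patch: fold a complex sqrt by converting `std::complex<double>` to the C `double _Complex` type, calling the C routine, and converting back. The helper names `ToC`, `ToCpp`, `FoldSqrtViaLibm`, and `CheckFoldSqrt` are hypothetical. The sketch assumes GCC or Clang, which accept `_Complex`, `__real__`, `__imag__`, and `__builtin_csqrt` in C++ as extensions.

```cpp
#include <cassert>
#include <complex>

// Hypothetical helper: std::complex<double> -> C-style complex.
// Relies on the GCC/Clang `_Complex` extension in C++ (assumption).
static double _Complex ToC(const std::complex<double> &x) {
  double _Complex z = 0.0;
  __real__ z = x.real();
  __imag__ z = x.imag();
  return z;
}

// Hypothetical helper: C-style complex -> std::complex<double>.
static std::complex<double> ToCpp(double _Complex z) {
  return {__real__ z, __imag__ z};
}

// Fold sqrt through the C routine instead of std::sqrt.
// __builtin_csqrt is the compiler builtin corresponding to libm's csqrt.
static std::complex<double> FoldSqrtViaLibm(const std::complex<double> &x) {
  return ToCpp(__builtin_csqrt(ToC(x)));
}

// Small self-check: sqrt(-4 + 0i) is 0 + 2i, and sqrt(4 + 0i) is 2 + 0i.
static void CheckFoldSqrt() {
  assert(FoldSqrtViaLibm({-4.0, 0.0}) == std::complex<double>(0.0, 2.0));
  assert(FoldSqrtViaLibm({4.0, 0.0}) == std::complex<double>(2.0, 0.0));
}
```

Routing the folder through the same C entry point that the runtime uses is what keeps a folded constant and a runtime-evaluated expression in agreement. The patch itself does this only under `#ifdef _AIX`.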

@kkwli kkwli self-assigned this Jan 23, 2025
@llvmbot llvmbot added the flang and flang:semantics labels Jan 23, 2025
Member

llvmbot commented Jan 23, 2025

@llvm/pr-subscribers-flang-semantics

Author: Kelvin Li (kkwli)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/124203.diff

1 Files Affected:

  • (modified) flang/lib/Evaluate/intrinsics-library.cpp (+182-56)
diff --git a/flang/lib/Evaluate/intrinsics-library.cpp b/flang/lib/Evaluate/intrinsics-library.cpp
index c1b270f518c0e0..60bb2785c725fd 100644
--- a/flang/lib/Evaluate/intrinsics-library.cpp
+++ b/flang/lib/Evaluate/intrinsics-library.cpp
@@ -260,23 +260,23 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
   static_assert(map.Verify(), "map must be sorted");
 };
 
-// Helpers to map complex std::pow whose resolution in F2{std::pow} is
-// ambiguous as of clang++ 20.
-template <typename HostT>
-static std::complex<HostT> StdPowF2(
-    const std::complex<HostT> &x, const std::complex<HostT> &y) {
-  return std::pow(x, y);
-}
-template <typename HostT>
-static std::complex<HostT> StdPowF2A(
-    const HostT &x, const std::complex<HostT> &y) {
-  return std::pow(x, y);
-}
-template <typename HostT>
-static std::complex<HostT> StdPowF2B(
-    const std::complex<HostT> &x, const HostT &y) {
-  return std::pow(x, y);
-}
+enum trigFunc {
+  Cacos,
+  Cacosh,
+  Casin,
+  Casinh,
+  Catan,
+  Catanh,
+  Ccos,
+  Ccosh,
+  Cexp,
+  Clog,
+  Csin,
+  Csinh,
+  Csqrt,
+  Ctan,
+  Ctanh
+};
 
 #ifdef _AIX
 #ifdef __clang_major__
@@ -286,8 +286,36 @@ static std::complex<HostT> StdPowF2B(
 extern "C" {
 float _Complex cacosf(float _Complex);
 double _Complex cacos(double _Complex);
+float _Complex cacoshf(float _Complex);
+double _Complex cacosh(double _Complex);
+float _Complex casinf(float _Complex);
+double _Complex casin(double _Complex);
+float _Complex casinhf(float _Complex);
+double _Complex casinh(double _Complex);
+float _Complex catanf(float _Complex);
+double _Complex catan(double _Complex);
+float _Complex catanhf(float _Complex);
+double _Complex catanh(double _Complex);
+float _Complex ccosf(float _Complex);
+double _Complex ccos(double _Complex);
+float _Complex ccoshf(float _Complex);
+double _Complex ccosh(double _Complex);
+float _Complex cexpf(float _Complex);
+double _Complex cexp(double _Complex);
+float _Complex clogf(float _Complex);
+double _Complex __clog(double _Complex);
+float _Complex cpowf(float _Complex, float _Complex);
+double _Complex cpow(double _Complex, double _Complex);
+float _Complex csinf(float _Complex);
+double _Complex csin(double _Complex);
+float _Complex csinhf(float _Complex);
+double _Complex csinh(double _Complex);
 float _Complex csqrtf(float _Complex);
 double _Complex csqrt(double _Complex);
+float _Complex ctanf(float _Complex);
+double _Complex ctan(double _Complex);
+float _Complex ctanhf(float _Complex);
+double _Complex ctanh(double _Complex);
 }
 
 enum CRI { Real, Imag };
@@ -304,48 +332,146 @@ template <typename T, typename TA> static std::complex<T> CToCpp(const TA &x) {
   TA &z{const_cast<TA &>(x)};
   return std::complex<T>(reIm<T, TA>(z, CRI::Real), reIm<T, TA>(z, CRI::Imag));
 }
+
+using FTypeCmplxFlt = _Complex float (*)(_Complex float);
+using FTypeCmplxDble = _Complex double (*)(_Complex double);
+template <typename T>
+using FTypeStdCmplx = std::complex<T> (*)(const std::complex<T>&);
+
+std::map<trigFunc, std::tuple<FTypeCmplxFlt, FTypeCmplxDble>> mapLibmTrigFunc{
+    {Cacos, {&cacosf, &cacos}}, {Cacosh, {&cacoshf, &cacosh}},
+    {Casin, {&casinf, &casin}}, {Casinh, {&casinhf, &casinh}},
+    {Catan, {&catanf, &catan}}, {Catanh, {&catanhf, &catanh}},
+    {Ccos, {&ccosf, &ccos}}, {Ccosh, {&ccoshf, &ccosh}},
+    {Cexp, {&cexpf, &cexp}}, {Clog, {&clogf, &__clog}},
+    {Csin, {&csinf, &csin}}, {Csinh, {&csinhf, &csinh}},
+    {Csqrt, {&csqrtf, &csqrt}}, {Ctan, {&ctanf, &ctan}},
+    {Ctanh, {&ctanhf, &ctanh}}};
+
+template <trigFunc TF, typename HostT>
+std::complex<HostT> LibmTrigFunc(const std::complex<HostT> &x) {
+  if constexpr (std::is_same_v<HostT, float>) {
+    float _Complex r{
+      std::get<FTypeCmplxFlt>(mapLibmTrigFunc[TF])(CppToC<float _Complex, float>(x))};
+    return CToCpp<float, float _Complex>(r);
+  } else if constexpr (std::is_same_v<HostT, double>) {
+    double _Complex r{
+      std::get<FTypeCmplxDble>(mapLibmTrigFunc[TF])(CppToC<double _Complex, double>(x))};
+    return CToCpp<double, double _Complex>(r);
+  }
+  DIE("bad complex component type");
+}
+#endif
+
+template <trigFunc TF, typename HostT>
+std::complex<HostT> StdTrigFunc(const std::complex<HostT> &x) {
+  if constexpr (TF == Cacos) {
+    return std::acos(x);
+  } else if constexpr (TF == Cacosh) {
+    return std::acosh(x);
+  } else if constexpr (TF == Casin) {
+    return std::asin(x);
+  } else if constexpr (TF == Casinh) {
+    return std::asinh(x);
+  } else if constexpr (TF == Catan) {
+    return std::atan(x);
+  } else if constexpr (TF == Catanh) {
+    return std::atanh(x);
+  } else if constexpr (TF == Ccos) {
+    return std::cos(x);
+  } else if constexpr (TF == Ccosh) {
+    return std::cosh(x);
+  } else if constexpr (TF == Cexp) {
+    return std::exp(x);
+  } else if constexpr (TF == Clog) {
+    return std::log(x);
+  } else if constexpr (TF == Csin) {
+    return std::sin(x);
+  } else if constexpr (TF == Csinh) {
+    return std::sinh(x);
+  } else if constexpr (TF == Csqrt) {
+    return std::sqrt(x);
+  } else if constexpr (TF == Ctan) {
+    return std::tan(x);
+  } else if constexpr (TF == Ctanh) {
+    return std::tanh(x);
+  }
+  DIE("unknown function");
+}
+
+template <trigFunc TF> struct X {
+  template <typename HostT>
+  static std::complex<HostT> f(const std::complex<HostT> &x) {
+    std::complex<HostT> res;
+#ifdef _AIX
+    // On AIX, the implementation in libm is different from that of std::
+    // routines, use the libm routines here in folding for consistent results.
+    res = LibmTrigFunc<TF>(x);
+#else
+    res = StdTrigFunc<TF, HostT>(x);
+#endif
+    return res;
+  }
+};
+
+// Helpers to map complex std::pow whose resolution in F2{std::pow} is
+// ambiguous as of clang++ 20.
+template <typename HostT>
+static std::complex<HostT> StdPowF2(const std::complex<HostT> &x,
+                                    const std::complex<HostT> &y) {
+#ifdef _AIX
+  if constexpr (std::is_same_v<HostT, float>) {
+    float _Complex r{cpowf(CppToC<float _Complex, float>(x),
+                           CppToC<float _Complex, float>(y))};
+    return CToCpp<float, float _Complex>(r);
+  } else if constexpr (std::is_same_v<HostT, double>) {
+    double _Complex r{cpow(CppToC<double _Complex, double>(x),
+                           CppToC<double _Complex, double>(y))};
+    return CToCpp<double, double _Complex>(r);
+  }
+#else
+  return std::pow(x, y);
 #endif
+}
 
 template <typename HostT>
-static std::complex<HostT> CSqrt(const std::complex<HostT> &x) {
-  std::complex<HostT> res;
+static std::complex<HostT> StdPowF2A(const HostT &x,
+                                     const std::complex<HostT> &y) {
 #ifdef _AIX
-  // On AIX, the implementation of csqrt[f] and std::sqrt is different,
-  // use csqrt[f] in folding.
+  constexpr HostT zero{0.0};
+  std::complex<HostT> z(x, zero);
   if constexpr (std::is_same_v<HostT, float>) {
-    float _Complex r{csqrtf(CppToC<float _Complex, float>(x))};
-    res = CToCpp<float, float _Complex>(r);
+    float _Complex r{cpowf(CppToC<float _Complex, float>(z),
+                           CppToC<float _Complex, float>(y))};
+    return CToCpp<float, float _Complex>(r);
   } else if constexpr (std::is_same_v<HostT, double>) {
-    double _Complex r{csqrt(CppToC<double _Complex, double>(x))};
-    res = CToCpp<double, double _Complex>(r);
-  } else {
-    DIE("bad complex component type");
+    double _Complex r{cpow(CppToC<double _Complex, double>(z),
+                           CppToC<double _Complex, double>(y))};
+    return CToCpp<double, double _Complex>(r);
   }
 #else
-  res = std::sqrt(x);
+  return std::pow(x, y);
 #endif
-  return res;
 }
 
 template <typename HostT>
-static std::complex<HostT> CAcos(const std::complex<HostT> &x) {
-  std::complex<HostT> res;
+static std::complex<HostT> StdPowF2B(const std::complex<HostT> &x,
+                                     const HostT &y) {
 #ifdef _AIX
-  // On AIX, the implementation of cacos[f] and std::acos is different,
-  // use cacos[f] in folding.
+  constexpr HostT zero{0.0};
+  std::complex<HostT> z(y, zero);
   if constexpr (std::is_same_v<HostT, float>) {
-    float _Complex r{cacosf(CppToC<float _Complex, float>(x))};
-    res = CToCpp<float, float _Complex>(r);
+    float _Complex r{cpowf(CppToC<float _Complex, float>(x),
+                           CppToC<float _Complex, float>(z))};
+    return CToCpp<float, float _Complex>(r);
   } else if constexpr (std::is_same_v<HostT, double>) {
-    double _Complex r{cacos(CppToC<double _Complex, double>(x))};
-    res = CToCpp<double, double _Complex>(r);
-  } else {
-    DIE("bad complex component type");
+    double _Complex r{cpow(CppToC<double _Complex, double>(x),
+                           CppToC<double _Complex, double>(z))};
+    return CToCpp<double, double _Complex>(r);
   }
 #else
-  res = std::acos(x);
+  return std::pow(x, y);
 #endif
-  return res;
 }
 
 template <typename HostT>
@@ -358,24 +484,24 @@ struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
   using F2B = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
       const HostT &>;
   static constexpr HostRuntimeFunction table[]{
-      FolderFactory<F, F{CAcos}>::Create("acos"),
-      FolderFactory<F, F{std::acosh}>::Create("acosh"),
-      FolderFactory<F, F{std::asin}>::Create("asin"),
-      FolderFactory<F, F{std::asinh}>::Create("asinh"),
-      FolderFactory<F, F{std::atan}>::Create("atan"),
-      FolderFactory<F, F{std::atanh}>::Create("atanh"),
-      FolderFactory<F, F{std::cos}>::Create("cos"),
-      FolderFactory<F, F{std::cosh}>::Create("cosh"),
-      FolderFactory<F, F{std::exp}>::Create("exp"),
-      FolderFactory<F, F{std::log}>::Create("log"),
+      FolderFactory<F, F{X<Cacos>::f}>::Create("acos"),
+      FolderFactory<F, F{X<Cacosh>::f}>::Create("acosh"),
+      FolderFactory<F, F{X<Casin>::f}>::Create("asin"),
+      FolderFactory<F, F{X<Casinh>::f}>::Create("asinh"),
+      FolderFactory<F, F{X<Catan>::f}>::Create("atan"),
+      FolderFactory<F, F{X<Catanh>::f}>::Create("atanh"),
+      FolderFactory<F, F{X<Ccos>::f}>::Create("cos"),
+      FolderFactory<F, F{X<Ccosh>::f}>::Create("cosh"),
+      FolderFactory<F, F{X<Cexp>::f}>::Create("exp"),
+      FolderFactory<F, F{X<Clog>::f}>::Create("log"),
       FolderFactory<F2, F2{StdPowF2}>::Create("pow"),
       FolderFactory<F2A, F2A{StdPowF2A}>::Create("pow"),
       FolderFactory<F2B, F2B{StdPowF2B}>::Create("pow"),
-      FolderFactory<F, F{std::sin}>::Create("sin"),
-      FolderFactory<F, F{std::sinh}>::Create("sinh"),
-      FolderFactory<F, F{CSqrt}>::Create("sqrt"),
-      FolderFactory<F, F{std::tan}>::Create("tan"),
-      FolderFactory<F, F{std::tanh}>::Create("tanh"),
+      FolderFactory<F, F{X<Csin>::f}>::Create("sin"),
+      FolderFactory<F, F{X<Csinh>::f}>::Create("sinh"),
+      FolderFactory<F, F{X<Csqrt>::f}>::Create("sqrt"),
+      FolderFactory<F, F{X<Ctan>::f}>::Create("tan"),
+      FolderFactory<F, F{X<Ctanh>::f}>::Create("tanh"),
   };
   static constexpr HostRuntimeMap map{table};
   static_assert(map.Verify(), "map must be sorted");

github-actions bot commented Jan 23, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@kkwli kkwli requested a review from jeanPerier January 26, 2025 05:12
Contributor

@jeanPerier jeanPerier left a comment

LGTM, thanks for addressing my comments.
I do not think the Windows failure is caused by your change, but you may want to rebase to make sure.

@kkwli kkwli merged commit a8d4335 into llvm:main Jan 30, 2025
8 checks passed
@kkwli kkwli deleted the c-trig branch January 30, 2025 01:27