Skip to content

Commit a8d4335

Browse files
authored
[flang][AIX] Handle more trig functions with complex argument to have consistent results in folding (#124203)
This patch extends commit 71d4f34 to all trig functions that take complex arguments. On AIX, compile-time folding now calls the `libm` routines instead of the C++ standard library (`std::`) routines.
1 parent 00c096e commit a8d4335

File tree

1 file changed

+131
-79
lines changed

1 file changed

+131
-79
lines changed

flang/lib/Evaluate/intrinsics-library.cpp

Lines changed: 131 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -260,105 +260,41 @@ struct HostRuntimeLibrary<HostT, LibraryVersion::Libm> {
260260
static_assert(map.Verify(), "map must be sorted");
261261
};
262262

263+
// Declares the function-pointer aliases used by the complex
// HostRuntimeLibrary tables for a component type HOST_T. In each alias the
// first FuncPointer argument is the result type (std::complex<HOST_T>) and
// the remaining ones are the parameters, all taken by const reference:
//   F   - (complex)
//   F2  - (complex, complex)
//   F2A - (HOST_T, complex)
//   F2B - (complex, HOST_T)
// The macro is expanded inside a struct body, so the aliases become members
// of that struct.
#define COMPLEX_SIGNATURES(HOST_T) \
264+
using F = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &>; \
265+
using F2 = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
266+
const std::complex<HOST_T> &>; \
267+
using F2A = FuncPointer<std::complex<HOST_T>, const HOST_T &, \
268+
const std::complex<HOST_T> &>; \
269+
using F2B = FuncPointer<std::complex<HOST_T>, const std::complex<HOST_T> &, \
270+
const HOST_T &>;
271+
272+
#ifndef _AIX
263273
// Helpers that wrap complex std::pow, because naming the overload set
264274
// directly (as in F2{std::pow}) is ambiguous as of clang++ 20.
265275
// StdPowF2: complex base x raised to complex exponent y; returns
// std::pow(x, y).
template <typename HostT>
266276
static std::complex<HostT> StdPowF2(
267277
const std::complex<HostT> &x, const std::complex<HostT> &y) {
268278
return std::pow(x, y);
269279
}
280+
270281
// StdPowF2A: real (HostT) base x raised to complex exponent y; returns
// std::pow(x, y). Its parameter list has the (HostT, complex) shape of the
// F2A signature.
template <typename HostT>
271282
static std::complex<HostT> StdPowF2A(
272283
const HostT &x, const std::complex<HostT> &y) {
273284
return std::pow(x, y);
274285
}
286+
275287
// StdPowF2B: complex base x raised to real (HostT) exponent y; returns
// std::pow(x, y). Registered under "pow" with the F2B signature in the
// non-AIX complex table.
template <typename HostT>
276288
static std::complex<HostT> StdPowF2B(
277289
const std::complex<HostT> &x, const HostT &y) {
278290
return std::pow(x, y);
279291
}
280292

281-
#ifdef _AIX
282-
#ifdef __clang_major__
283-
#pragma clang diagnostic ignored "-Wc99-extensions"
284-
#endif
285-
286-
extern "C" {
287-
float _Complex cacosf(float _Complex);
288-
double _Complex cacos(double _Complex);
289-
float _Complex csqrtf(float _Complex);
290-
double _Complex csqrt(double _Complex);
291-
}
292-
293-
enum CRI { Real, Imag };
294-
template <typename TR, typename TA> static TR &reIm(TA &x, CRI n) {
295-
return reinterpret_cast<TR(&)[2]>(x)[n];
296-
}
297-
template <typename TR, typename T> static TR CppToC(const std::complex<T> &x) {
298-
TR r;
299-
reIm<T, TR>(r, CRI::Real) = x.real();
300-
reIm<T, TR>(r, CRI::Imag) = x.imag();
301-
return r;
302-
}
303-
template <typename T, typename TA> static std::complex<T> CToCpp(const TA &x) {
304-
TA &z{const_cast<TA &>(x)};
305-
return std::complex<T>(reIm<T, TA>(z, CRI::Real), reIm<T, TA>(z, CRI::Imag));
306-
}
307-
#endif
308-
309-
template <typename HostT>
310-
static std::complex<HostT> CSqrt(const std::complex<HostT> &x) {
311-
std::complex<HostT> res;
312-
#ifdef _AIX
313-
// On AIX, the implementation of csqrt[f] and std::sqrt is different,
314-
// use csqrt[f] in folding.
315-
if constexpr (std::is_same_v<HostT, float>) {
316-
float _Complex r{csqrtf(CppToC<float _Complex, float>(x))};
317-
res = CToCpp<float, float _Complex>(r);
318-
} else if constexpr (std::is_same_v<HostT, double>) {
319-
double _Complex r{csqrt(CppToC<double _Complex, double>(x))};
320-
res = CToCpp<double, double _Complex>(r);
321-
} else {
322-
DIE("bad complex component type");
323-
}
324-
#else
325-
res = std::sqrt(x);
326-
#endif
327-
return res;
328-
}
329-
330-
template <typename HostT>
331-
static std::complex<HostT> CAcos(const std::complex<HostT> &x) {
332-
std::complex<HostT> res;
333-
#ifdef _AIX
334-
// On AIX, the implementation of cacos[f] and std::acos is different,
335-
// use cacos[f] in folding.
336-
if constexpr (std::is_same_v<HostT, float>) {
337-
float _Complex r{cacosf(CppToC<float _Complex, float>(x))};
338-
res = CToCpp<float, float _Complex>(r);
339-
} else if constexpr (std::is_same_v<HostT, double>) {
340-
double _Complex r{cacos(CppToC<double _Complex, double>(x))};
341-
res = CToCpp<double, double _Complex>(r);
342-
} else {
343-
DIE("bad complex component type");
344-
}
345-
#else
346-
res = std::acos(x);
347-
#endif
348-
return res;
349-
}
350-
351293
template <typename HostT>
352294
struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
353-
using F = FuncPointer<std::complex<HostT>, const std::complex<HostT> &>;
354-
using F2 = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
355-
const std::complex<HostT> &>;
356-
using F2A = FuncPointer<std::complex<HostT>, const HostT &,
357-
const std::complex<HostT> &>;
358-
using F2B = FuncPointer<std::complex<HostT>, const std::complex<HostT> &,
359-
const HostT &>;
295+
COMPLEX_SIGNATURES(HostT)
360296
static constexpr HostRuntimeFunction table[]{
361-
FolderFactory<F, F{CAcos}>::Create("acos"),
297+
FolderFactory<F, F{std::acos}>::Create("acos"),
362298
FolderFactory<F, F{std::acosh}>::Create("acosh"),
363299
FolderFactory<F, F{std::asin}>::Create("asin"),
364300
FolderFactory<F, F{std::asinh}>::Create("asinh"),
@@ -373,13 +309,129 @@ struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
373309
FolderFactory<F2B, F2B{StdPowF2B}>::Create("pow"),
374310
FolderFactory<F, F{std::sin}>::Create("sin"),
375311
FolderFactory<F, F{std::sinh}>::Create("sinh"),
376-
FolderFactory<F, F{CSqrt}>::Create("sqrt"),
312+
FolderFactory<F, F{std::sqrt}>::Create("sqrt"),
377313
FolderFactory<F, F{std::tan}>::Create("tan"),
378314
FolderFactory<F, F{std::tanh}>::Create("tanh"),
379315
};
380316
static constexpr HostRuntimeMap map{table};
381317
static_assert(map.Verify(), "map must be sorted");
382318
};
319+
#else
320+
// On AIX, call the C libm routines so that compile-time folding yields the
321+
// same values as run-time evaluation.
322+
// The code that follows spells the C99 `_Complex` types in C++, so clang's
// -Wc99-extensions diagnostic is disabled when building with clang.
#ifdef __clang_major__
323+
#pragma clang diagnostic ignored "-Wc99-extensions"
324+
#endif
325+
326+
// C-linkage prototypes for the complex libm entry points used by the AIX
// folding tables. Each routine is declared in a float (`...f`) and a double
// flavour, taking and returning C `_Complex` values; cpow/cpowf take two
// arguments, every other routine takes one.
extern "C" {
327+
float _Complex cacosf(float _Complex);
328+
double _Complex cacos(double _Complex);
329+
float _Complex cacoshf(float _Complex);
330+
double _Complex cacosh(double _Complex);
331+
float _Complex casinf(float _Complex);
332+
double _Complex casin(double _Complex);
333+
float _Complex casinhf(float _Complex);
334+
double _Complex casinh(double _Complex);
335+
float _Complex catanf(float _Complex);
336+
double _Complex catan(double _Complex);
337+
float _Complex catanhf(float _Complex);
338+
double _Complex catanh(double _Complex);
339+
float _Complex ccosf(float _Complex);
340+
double _Complex ccos(double _Complex);
341+
float _Complex ccoshf(float _Complex);
342+
double _Complex ccosh(double _Complex);
343+
float _Complex cexpf(float _Complex);
344+
double _Complex cexp(double _Complex);
345+
float _Complex clogf(float _Complex);
346+
// NOTE(review): the double-precision log routine is declared as `__clog`,
// not `clog`, unlike every other entry here. Presumably this is the symbol
// AIX's <complex.h> uses for it -- confirm against the AIX headers.
double _Complex __clog(double _Complex);
347+
float _Complex cpowf(float _Complex, float _Complex);
348+
double _Complex cpow(double _Complex, double _Complex);
349+
float _Complex csinf(float _Complex);
350+
double _Complex csin(double _Complex);
351+
float _Complex csinhf(float _Complex);
352+
double _Complex csinh(double _Complex);
353+
float _Complex csqrtf(float _Complex);
354+
double _Complex csqrt(double _Complex);
355+
float _Complex ctanf(float _Complex);
356+
double _Complex ctan(double _Complex);
357+
float _Complex ctanhf(float _Complex);
358+
double _Complex ctanh(double _Complex);
359+
}
360+
361+
// Type trait used by CComplexFunc to translate the types of a C routine's
// signature into the types of its C++ wrapper:
//   Type  - the C++ value type corresponding to T
//   AType - how a value of that type is passed as a wrapper parameter
// Primary template: T is used unchanged, both as value and as parameter.
template <typename T> struct ToStdComplex {
362+
using Type = T;
363+
using AType = Type;
364+
};
365+
// `float _Complex` maps to std::complex<float>, passed by const reference.
template <> struct ToStdComplex<float _Complex> {
366+
using Type = std::complex<float>;
367+
using AType = const Type &;
368+
};
369+
// `double _Complex` maps to std::complex<double>, passed by const reference.
template <> struct ToStdComplex<double _Complex> {
370+
using Type = std::complex<double>;
371+
using AType = const Type &;
372+
};
373+
374+
// Adapts a C function `func` to a static function `wrapper` whose parameter
// and result types are the ToStdComplex translations of func's own types.
// The primary template is intentionally empty; only the function-pointer
// partial specialization below provides `wrapper`.
template <typename F, F func> struct CComplexFunc {};
375+
// Specialization for FuncPointer<R, A...>: R is func's result type and A...
// are its parameter types.
template <typename R, typename... A, FuncPointer<R, A...> func>
376+
struct CComplexFunc<FuncPointer<R, A...>, func> {
377+
static typename ToStdComplex<R>::Type wrapper(
378+
typename ToStdComplex<A>::AType... args) {
379+
// Reinterpret each C++ argument as the corresponding C parameter type and
// call func. This relies on std::complex<T> and the C `T _Complex` type
// having the same object representation.
R res{func(*reinterpret_cast<const A *>(&args)...)};
380+
// Reinterpret the C result in place as the C++ result type and return a
// copy of it.
return *reinterpret_cast<typename ToStdComplex<R>::Type *>(&res);
381+
}
382+
};
383+
// Yields the wrapper for the C function named `func`; used as the function
// argument of FolderFactory in the tables below.
#define C_COMPLEX_FUNC(func) CComplexFunc<decltype(&func), &func>::wrapper
384+
385+
// AIX host library table for std::complex<float>: each intrinsic name is
// bound to the wrapper around the single-precision C libm routine.
// The static_assert at the end requires map.Verify() to hold; per its
// message the table must be sorted, so keep the entries ordered by name.
// NOTE(review): "pow" is registered only with the (complex, complex)
// signature F2; the F2A and F2B aliases from COMPLEX_SIGNATURES are not used
// in this table.
template <>
386+
struct HostRuntimeLibrary<std::complex<float>, LibraryVersion::Libm> {
387+
COMPLEX_SIGNATURES(float)
388+
static constexpr HostRuntimeFunction table[]{
389+
FolderFactory<F, C_COMPLEX_FUNC(cacosf)>::Create("acos"),
390+
FolderFactory<F, C_COMPLEX_FUNC(cacoshf)>::Create("acosh"),
391+
FolderFactory<F, C_COMPLEX_FUNC(casinf)>::Create("asin"),
392+
FolderFactory<F, C_COMPLEX_FUNC(casinhf)>::Create("asinh"),
393+
FolderFactory<F, C_COMPLEX_FUNC(catanf)>::Create("atan"),
394+
FolderFactory<F, C_COMPLEX_FUNC(catanhf)>::Create("atanh"),
395+
FolderFactory<F, C_COMPLEX_FUNC(ccosf)>::Create("cos"),
396+
FolderFactory<F, C_COMPLEX_FUNC(ccoshf)>::Create("cosh"),
397+
FolderFactory<F, C_COMPLEX_FUNC(cexpf)>::Create("exp"),
398+
FolderFactory<F, C_COMPLEX_FUNC(clogf)>::Create("log"),
399+
FolderFactory<F2, C_COMPLEX_FUNC(cpowf)>::Create("pow"),
400+
FolderFactory<F, C_COMPLEX_FUNC(csinf)>::Create("sin"),
401+
FolderFactory<F, C_COMPLEX_FUNC(csinhf)>::Create("sinh"),
402+
FolderFactory<F, C_COMPLEX_FUNC(csqrtf)>::Create("sqrt"),
403+
FolderFactory<F, C_COMPLEX_FUNC(ctanf)>::Create("tan"),
404+
FolderFactory<F, C_COMPLEX_FUNC(ctanhf)>::Create("tanh"),
405+
};
406+
static constexpr HostRuntimeMap map{table};
407+
static_assert(map.Verify(), "map must be sorted");
408+
};
409+
// AIX host library table for std::complex<double>: each intrinsic name is
// bound to the wrapper around the double-precision C libm routine ("log"
// uses the routine declared as __clog).
// The static_assert at the end requires map.Verify() to hold; per its
// message the table must be sorted, so keep the entries ordered by name.
// NOTE(review): "pow" is registered only with the (complex, complex)
// signature F2; the F2A and F2B aliases from COMPLEX_SIGNATURES are not used
// in this table.
template <>
410+
struct HostRuntimeLibrary<std::complex<double>, LibraryVersion::Libm> {
411+
COMPLEX_SIGNATURES(double)
412+
static constexpr HostRuntimeFunction table[]{
413+
FolderFactory<F, C_COMPLEX_FUNC(cacos)>::Create("acos"),
414+
FolderFactory<F, C_COMPLEX_FUNC(cacosh)>::Create("acosh"),
415+
FolderFactory<F, C_COMPLEX_FUNC(casin)>::Create("asin"),
416+
FolderFactory<F, C_COMPLEX_FUNC(casinh)>::Create("asinh"),
417+
FolderFactory<F, C_COMPLEX_FUNC(catan)>::Create("atan"),
418+
FolderFactory<F, C_COMPLEX_FUNC(catanh)>::Create("atanh"),
419+
FolderFactory<F, C_COMPLEX_FUNC(ccos)>::Create("cos"),
420+
FolderFactory<F, C_COMPLEX_FUNC(ccosh)>::Create("cosh"),
421+
FolderFactory<F, C_COMPLEX_FUNC(cexp)>::Create("exp"),
422+
FolderFactory<F, C_COMPLEX_FUNC(__clog)>::Create("log"),
423+
FolderFactory<F2, C_COMPLEX_FUNC(cpow)>::Create("pow"),
424+
FolderFactory<F, C_COMPLEX_FUNC(csin)>::Create("sin"),
425+
FolderFactory<F, C_COMPLEX_FUNC(csinh)>::Create("sinh"),
426+
FolderFactory<F, C_COMPLEX_FUNC(csqrt)>::Create("sqrt"),
427+
FolderFactory<F, C_COMPLEX_FUNC(ctan)>::Create("tan"),
428+
FolderFactory<F, C_COMPLEX_FUNC(ctanh)>::Create("tanh"),
429+
};
430+
static constexpr HostRuntimeMap map{table};
431+
static_assert(map.Verify(), "map must be sorted");
432+
};
433+
#endif // _AIX
434+
383435
// Note regarding cmath:
384436
// - cmath does not have modulo and erfc_scaled equivalent
385437
// - C++17 defined standard Bessel math functions std::cyl_bessel_j

0 commit comments

Comments
 (0)