Skip to content

[flang][runtime] Support SUM/PRODUCT/DOT_PRODUCT reductions for REAL(16). #83169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions flang/include/flang/Common/float128.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,20 @@
#endif /* (defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)) && \
!defined(_LIBCPP_VERSION) && !defined(__CUDA_ARCH__) */

/* Define pure C CFloat128Type and CFloat128ComplexType. */
#if LDBL_MANT_DIG == 113
typedef long double CFloat128Type;
typedef long double _Complex CFloat128ComplexType;
#elif HAS_FLOAT128
typedef __float128 CFloat128Type;
/*
* Use mode() attribute supported by GCC and Clang.
* Adjust it for other compilers as needed.
*/
#if !defined(_ARCH_PPC) || defined(__LONG_DOUBLE_IEEE128__)
typedef _Complex float __attribute__((mode(TC))) CFloat128ComplexType;
#else
typedef _Complex float __attribute__((mode(KC))) CFloat128ComplexType;
#endif
#endif
#endif /* FORTRAN_COMMON_FLOAT128_H_ */
12 changes: 9 additions & 3 deletions flang/include/flang/Runtime/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,11 @@ void RTDECL(CppSumComplex8)(std::complex<double> &, const Descriptor &,
void RTDECL(CppSumComplex10)(std::complex<long double> &, const Descriptor &,
const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
void RTDECL(CppSumComplex16)(std::complex<long double> &, const Descriptor &,
const char *source, int line, int dim = 0,
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
void RTDECL(CppSumComplex16)(std::complex<CppFloat128Type> &,
const Descriptor &, const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
#endif

void RTDECL(SumDim)(Descriptor &result, const Descriptor &array, int dim,
const char *source, int line, const Descriptor *mask = nullptr);
Expand Down Expand Up @@ -145,12 +147,16 @@ void RTDECL(CppProductComplex4)(std::complex<float> &, const Descriptor &,
void RTDECL(CppProductComplex8)(std::complex<double> &, const Descriptor &,
const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
#if LDBL_MANT_DIG == 64
void RTDECL(CppProductComplex10)(std::complex<long double> &,
const Descriptor &, const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
void RTDECL(CppProductComplex16)(std::complex<long double> &,
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
void RTDECL(CppProductComplex16)(std::complex<CppFloat128Type> &,
const Descriptor &, const char *source, int line, int dim = 0,
const Descriptor *mask = nullptr);
#endif

void RTDECL(ProductDim)(Descriptor &result, const Descriptor &array, int dim,
const char *source, int line, const Descriptor *mask = nullptr);
Expand Down
6 changes: 2 additions & 4 deletions flang/runtime/Float128Math/cabs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ namespace Fortran::runtime {
extern "C" {

#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
// FIXME: the argument should be CppTypeFor<TypeCategory::Complex, 16>,
// and it should be translated into the underlying library's
// corresponding complex128 type.
CppTypeFor<TypeCategory::Real, 16> RTDEF(CAbsF128)(ComplexF128 x) {
// NOTE: Flang calls the runtime APIs using C _Complex ABI
CppTypeFor<TypeCategory::Real, 16> RTDEF(CAbsF128)(CFloat128ComplexType x) {
return CAbs<RTNAME(CAbsF128)>::invoke(x);
}
#endif
Expand Down
9 changes: 0 additions & 9 deletions flang/runtime/Float128Math/math-entries.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,6 @@ DEFINE_FALLBACK(Y0)
DEFINE_FALLBACK(Y1)
DEFINE_FALLBACK(Yn)

// Define ComplexF128 type that is compatible with
// the type of results/arguments of libquadmath.
// TODO: this may need more work for other libraries/compilers.
#if !defined(_ARCH_PPC) || defined(__LONG_DOUBLE_IEEE128__)
typedef _Complex float __attribute__((mode(TC))) ComplexF128;
#else
typedef _Complex float __attribute__((mode(KC))) ComplexF128;
#endif

#if HAS_LIBM
// Define wrapper callers for libm.
#include <ccomplex>
Expand Down
47 changes: 38 additions & 9 deletions flang/runtime/complex-reduction.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ struct CppComplexDouble {
struct CppComplexLongDouble {
long double r, i;
};
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
struct CppComplexFloat128 {
CFloat128Type r, i;
};
#endif

/* Not all environments define CMPLXF, CMPLX, CMPLXL. */

Expand Down Expand Up @@ -70,6 +75,27 @@ static long_double_Complex_t CMPLXL(long double r, long double i) {
#endif
#endif

#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
/*
* GCC 7.4.0 (currently minimum GCC version for llvm builds)
* supports __builtin_complex. For Clang, require >=12.0.
* Otherwise, rely on the memory layout compatibility.
*/
#if (defined(__clang_major__) && (__clang_major__ >= 12)) || defined(__GNUC__)
#define CMPLXF128 __builtin_complex
#else
static CFloat128ComplexType CMPLXF128(CFloat128Type r, CFloat128Type i) {
union {
struct CppComplexFloat128 x;
CFloat128ComplexType result;
} u;
u.x.r = r;
u.x.i = i;
return u.result;
}
#endif
#endif

/* RTNAME(SumComplex4) calls RTNAME(CppSumComplex4) with the same arguments
* and converts the members of its C++ complex result to C _Complex.
*/
Expand All @@ -93,9 +119,10 @@ ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX,
#if LDBL_MANT_DIG == 64
ADAPT_REDUCTION(SumComplex10, long_double_Complex_t, CppComplexLongDouble,
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#elif LDBL_MANT_DIG == 113
ADAPT_REDUCTION(SumComplex16, long_double_Complex_t, CppComplexLongDouble,
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
ADAPT_REDUCTION(SumComplex16, CFloat128ComplexType, CppComplexFloat128,
CMPLXF128, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#endif

/* PRODUCT() */
Expand All @@ -106,9 +133,10 @@ ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
#if LDBL_MANT_DIG == 64
ADAPT_REDUCTION(ProductComplex10, long_double_Complex_t, CppComplexLongDouble,
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#elif LDBL_MANT_DIG == 113
ADAPT_REDUCTION(ProductComplex16, long_double_Complex_t, CppComplexLongDouble,
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
ADAPT_REDUCTION(ProductComplex16, CFloat128ComplexType, CppComplexFloat128,
CMPLXF128, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
#endif

/* DOT_PRODUCT() */
Expand All @@ -119,7 +147,8 @@ ADAPT_REDUCTION(DotProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
#if LDBL_MANT_DIG == 64
ADAPT_REDUCTION(DotProductComplex10, long_double_Complex_t,
CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
#elif LDBL_MANT_DIG == 113
ADAPT_REDUCTION(DotProductComplex16, long_double_Complex_t,
CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
ADAPT_REDUCTION(DotProductComplex16, CFloat128ComplexType, CppComplexFloat128,
CMPLXF128, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
#endif
13 changes: 10 additions & 3 deletions flang/runtime/complex-reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
#define FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_

#include "flang/Common/float128.h"
#include "flang/Runtime/entry-names.h"
#include <complex.h>

Expand All @@ -40,14 +41,18 @@ float_Complex_t RTNAME(SumComplex3)(REDUCTION_ARGS);
float_Complex_t RTNAME(SumComplex4)(REDUCTION_ARGS);
double_Complex_t RTNAME(SumComplex8)(REDUCTION_ARGS);
long_double_Complex_t RTNAME(SumComplex10)(REDUCTION_ARGS);
long_double_Complex_t RTNAME(SumComplex16)(REDUCTION_ARGS);
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
CFloat128ComplexType RTNAME(SumComplex16)(REDUCTION_ARGS);
#endif

float_Complex_t RTNAME(ProductComplex2)(REDUCTION_ARGS);
float_Complex_t RTNAME(ProductComplex3)(REDUCTION_ARGS);
float_Complex_t RTNAME(ProductComplex4)(REDUCTION_ARGS);
double_Complex_t RTNAME(ProductComplex8)(REDUCTION_ARGS);
long_double_Complex_t RTNAME(ProductComplex10)(REDUCTION_ARGS);
long_double_Complex_t RTNAME(ProductComplex16)(REDUCTION_ARGS);
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
CFloat128ComplexType RTNAME(ProductComplex16)(REDUCTION_ARGS);
#endif

#define DOT_PRODUCT_ARGS \
const struct CppDescriptor *x, const struct CppDescriptor *y, \
Expand All @@ -60,6 +65,8 @@ float_Complex_t RTNAME(DotProductComplex3)(DOT_PRODUCT_ARGS);
float_Complex_t RTNAME(DotProductComplex4)(DOT_PRODUCT_ARGS);
double_Complex_t RTNAME(DotProductComplex8)(DOT_PRODUCT_ARGS);
long_double_Complex_t RTNAME(DotProductComplex10)(DOT_PRODUCT_ARGS);
long_double_Complex_t RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
CFloat128ComplexType RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
#endif

#endif // FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
6 changes: 4 additions & 2 deletions flang/runtime/product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ CppTypeFor<TypeCategory::Real, 10> RTDEF(ProductReal10)(const Descriptor &x,
NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x},
"PRODUCT");
}
#elif LDBL_MANT_DIG == 113
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
CppTypeFor<TypeCategory::Real, 16> RTDEF(ProductReal16)(const Descriptor &x,
const char *source, int line, int dim, const Descriptor *mask) {
return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim, mask,
Expand Down Expand Up @@ -154,7 +155,8 @@ void RTDEF(CppProductComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x},
"PRODUCT");
}
#elif LDBL_MANT_DIG == 113
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
const Descriptor &x, const char *source, int line, int dim,
const Descriptor *mask) {
Expand Down
3 changes: 2 additions & 1 deletion flang/runtime/sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ void RTDEF(CppSumComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
result = GetTotalReduction<TypeCategory::Complex, 10>(
x, source, line, dim, mask, ComplexSumAccumulator<long double>{x}, "SUM");
}
#elif LDBL_MANT_DIG == 113
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
const Descriptor &x, const char *source, int line, int dim,
const Descriptor *mask) {
Expand Down