Skip to content

[flang][runtime] Use cuda::std::complex in F18 runtime CUDA build. #109078

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions flang/include/flang/Common/float80.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*===-- flang/Common/float80.h --------------------------------------*- C -*-===
*
* Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
* See https://llvm.org/LICENSE.txt for license information.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*
*===----------------------------------------------------------------------===*/

/* This header is usable in both C and C++ code.
* Isolates build compiler checks to determine if the 80-bit
* floating point format is supported via a particular C type.
* It defines CFloat80Type and CppFloat80Type aliases for this
* C type.
*/

#ifndef FORTRAN_COMMON_FLOAT80_H_
#define FORTRAN_COMMON_FLOAT80_H_

#include "api-attrs.h"
#include <float.h>

#if LDBL_MANT_DIG == 64
#undef HAS_FLOAT80
#define HAS_FLOAT80 1
#endif

#if defined(RT_DEVICE_COMPILATION) && defined(__CUDACC__)
/*
* 'long double' is treated as 'double' in the CUDA device code,
* and there is no support for 80-bit floating point format.
* This is probably true for most offload devices, so RT_DEVICE_COMPILATION
* check should be enough. For the time being, guard it with __CUDACC__
* as well.
*/
#undef HAS_FLOAT80
#endif

#if HAS_FLOAT80
typedef long double CFloat80Type;
typedef long double CppFloat80Type;
#endif

#endif /* FORTRAN_COMMON_FLOAT80_H_ */
31 changes: 31 additions & 0 deletions flang/include/flang/Runtime/complex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===-- include/flang/Runtime/complex.h -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// A single way to expose C++ complex class in files that can be used
// in F18 runtime build. With inclusion of this file std::complex
// and the related names become available, though, they may correspond
// to alternative definitions (e.g. from cuda::std namespace).

#ifndef FORTRAN_RUNTIME_COMPLEX_H
#define FORTRAN_RUNTIME_COMPLEX_H

#if RT_USE_LIBCUDACXX
#include <cuda/std/complex>
namespace Fortran::runtime::rtcmplx {
using cuda::std::complex;
using cuda::std::conj;
} // namespace Fortran::runtime::rtcmplx
#else // !RT_USE_LIBCUDACXX
#include <complex>
namespace Fortran::runtime::rtcmplx {
using std::complex;
using std::conj;
} // namespace Fortran::runtime::rtcmplx
#endif // !RT_USE_LIBCUDACXX

#endif // FORTRAN_RUNTIME_COMPLEX_H
9 changes: 5 additions & 4 deletions flang/include/flang/Runtime/cpp-type.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@

#include "flang/Common/Fortran.h"
#include "flang/Common/float128.h"
#include "flang/Common/float80.h"
#include "flang/Common/uint128.h"
#include <complex>
#include "flang/Runtime/complex.h"
#include <cstdint>
#if __cplusplus >= 202302
#include <stdfloat>
Expand Down Expand Up @@ -70,9 +71,9 @@ template <> struct CppTypeForHelper<TypeCategory::Real, 8> {
using type = double;
#endif
};
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
template <> struct CppTypeForHelper<TypeCategory::Real, 10> {
using type = long double;
using type = CppFloat80Type;
};
#endif
#if __STDCPP_FLOAT128_T__
Expand All @@ -89,7 +90,7 @@ template <> struct CppTypeForHelper<TypeCategory::Real, 16> {
#endif

template <int KIND> struct CppTypeForHelper<TypeCategory::Complex, KIND> {
using type = std::complex<CppTypeFor<TypeCategory::Real, KIND>>;
using type = rtcmplx::complex<CppTypeFor<TypeCategory::Real, KIND>>;
};

template <> struct CppTypeForHelper<TypeCategory::Character, 1> {
Expand Down
6 changes: 3 additions & 3 deletions flang/include/flang/Runtime/matmul-instances.inc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)

#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
#if MATMUL_FORCE_ALL_TYPES || HAS_FLOAT80
MATMUL_INSTANCE(Integer, 16, Real, 10)
MATMUL_INSTANCE(Integer, 16, Complex, 10)
MATMUL_INSTANCE(Real, 10, Integer, 16)
Expand All @@ -133,7 +133,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
#endif
#endif // MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)

#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
#if MATMUL_FORCE_ALL_TYPES || HAS_FLOAT80
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro) \
macro(Integer, 1, Real, 10) \
macro(Integer, 1, Complex, 10) \
Expand Down Expand Up @@ -193,7 +193,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
#endif
#endif // MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
#endif // MATMUL_FORCE_ALL_TYPES || HAS_FLOAT80

#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro) \
Expand Down
32 changes: 16 additions & 16 deletions flang/include/flang/Runtime/numeric.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ CppTypeFor<TypeCategory::Integer, 8> RTDECL(Ceiling8_8)(
CppTypeFor<TypeCategory::Integer, 16> RTDECL(Ceiling8_16)(
CppTypeFor<TypeCategory::Real, 8>);
#endif
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Integer, 1> RTDECL(Ceiling10_1)(
CppTypeFor<TypeCategory::Real, 10>);
CppTypeFor<TypeCategory::Integer, 2> RTDECL(Ceiling10_2)(
Expand Down Expand Up @@ -78,7 +78,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(ErfcScaled4)(
CppTypeFor<TypeCategory::Real, 4>);
CppTypeFor<TypeCategory::Real, 8> RTDECL(ErfcScaled8)(
CppTypeFor<TypeCategory::Real, 8>);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(ErfcScaled10)(
CppTypeFor<TypeCategory::Real, 10>);
#endif
Expand All @@ -96,7 +96,7 @@ CppTypeFor<TypeCategory::Integer, 4> RTDECL(Exponent8_4)(
CppTypeFor<TypeCategory::Real, 8>);
CppTypeFor<TypeCategory::Integer, 8> RTDECL(Exponent8_8)(
CppTypeFor<TypeCategory::Real, 8>);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Integer, 4> RTDECL(Exponent10_4)(
CppTypeFor<TypeCategory::Real, 10>);
CppTypeFor<TypeCategory::Integer, 8> RTDECL(Exponent10_8)(
Expand Down Expand Up @@ -134,7 +134,7 @@ CppTypeFor<TypeCategory::Integer, 8> RTDECL(Floor8_8)(
CppTypeFor<TypeCategory::Integer, 16> RTDECL(Floor8_16)(
CppTypeFor<TypeCategory::Real, 8>);
#endif
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Integer, 1> RTDECL(Floor10_1)(
CppTypeFor<TypeCategory::Real, 10>);
CppTypeFor<TypeCategory::Integer, 2> RTDECL(Floor10_2)(
Expand Down Expand Up @@ -168,7 +168,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(Fraction4)(
CppTypeFor<TypeCategory::Real, 4>);
CppTypeFor<TypeCategory::Real, 8> RTDECL(Fraction8)(
CppTypeFor<TypeCategory::Real, 8>);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(Fraction10)(
CppTypeFor<TypeCategory::Real, 10>);
#endif
Expand All @@ -180,7 +180,7 @@ CppTypeFor<TypeCategory::Real, 16> RTDECL(Fraction16)(
// ISNAN / IEEE_IS_NAN
bool RTDECL(IsNaN4)(CppTypeFor<TypeCategory::Real, 4>);
bool RTDECL(IsNaN8)(CppTypeFor<TypeCategory::Real, 8>);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
bool RTDECL(IsNaN10)(CppTypeFor<TypeCategory::Real, 10>);
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
Expand Down Expand Up @@ -212,7 +212,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(ModReal4)(
CppTypeFor<TypeCategory::Real, 8> RTDECL(ModReal8)(
CppTypeFor<TypeCategory::Real, 8>, CppTypeFor<TypeCategory::Real, 8>,
const char *sourceFile = nullptr, int sourceLine = 0);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(ModReal10)(
CppTypeFor<TypeCategory::Real, 10>, CppTypeFor<TypeCategory::Real, 10>,
const char *sourceFile = nullptr, int sourceLine = 0);
Expand Down Expand Up @@ -247,7 +247,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(ModuloReal4)(
CppTypeFor<TypeCategory::Real, 8> RTDECL(ModuloReal8)(
CppTypeFor<TypeCategory::Real, 8>, CppTypeFor<TypeCategory::Real, 8>,
const char *sourceFile = nullptr, int sourceLine = 0);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(ModuloReal10)(
CppTypeFor<TypeCategory::Real, 10>, CppTypeFor<TypeCategory::Real, 10>,
const char *sourceFile = nullptr, int sourceLine = 0);
Expand Down Expand Up @@ -283,7 +283,7 @@ CppTypeFor<TypeCategory::Integer, 8> RTDECL(Nint8_8)(
CppTypeFor<TypeCategory::Integer, 16> RTDECL(Nint8_16)(
CppTypeFor<TypeCategory::Real, 8>);
#endif
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Integer, 1> RTDECL(Nint10_1)(
CppTypeFor<TypeCategory::Real, 10>);
CppTypeFor<TypeCategory::Integer, 2> RTDECL(Nint10_2)(
Expand Down Expand Up @@ -319,7 +319,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(Nearest4)(
CppTypeFor<TypeCategory::Real, 4>, bool positive);
CppTypeFor<TypeCategory::Real, 8> RTDECL(Nearest8)(
CppTypeFor<TypeCategory::Real, 8>, bool positive);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(Nearest10)(
CppTypeFor<TypeCategory::Real, 10>, bool positive);
#endif
Expand All @@ -333,7 +333,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(RRSpacing4)(
CppTypeFor<TypeCategory::Real, 4>);
CppTypeFor<TypeCategory::Real, 8> RTDECL(RRSpacing8)(
CppTypeFor<TypeCategory::Real, 8>);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(RRSpacing10)(
CppTypeFor<TypeCategory::Real, 10>);
#endif
Expand All @@ -347,7 +347,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(SetExponent4)(
CppTypeFor<TypeCategory::Real, 4>, std::int64_t);
CppTypeFor<TypeCategory::Real, 8> RTDECL(SetExponent8)(
CppTypeFor<TypeCategory::Real, 8>, std::int64_t);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(SetExponent10)(
CppTypeFor<TypeCategory::Real, 10>, std::int64_t);
#endif
Expand All @@ -361,7 +361,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(Scale4)(
CppTypeFor<TypeCategory::Real, 4>, std::int64_t);
CppTypeFor<TypeCategory::Real, 8> RTDECL(Scale8)(
CppTypeFor<TypeCategory::Real, 8>, std::int64_t);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(Scale10)(
CppTypeFor<TypeCategory::Real, 10>, std::int64_t);
#endif
Expand Down Expand Up @@ -410,7 +410,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(Spacing4)(
CppTypeFor<TypeCategory::Real, 4>);
CppTypeFor<TypeCategory::Real, 8> RTDECL(Spacing8)(
CppTypeFor<TypeCategory::Real, 8>);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(Spacing10)(
CppTypeFor<TypeCategory::Real, 10>);
#endif
Expand All @@ -425,7 +425,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(FPow4i)(
CppTypeFor<TypeCategory::Real, 8> RTDECL(FPow8i)(
CppTypeFor<TypeCategory::Real, 8> b,
CppTypeFor<TypeCategory::Integer, 4> e);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(FPow10i)(
CppTypeFor<TypeCategory::Real, 10> b,
CppTypeFor<TypeCategory::Integer, 4> e);
Expand All @@ -442,7 +442,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(FPow4k)(
CppTypeFor<TypeCategory::Real, 8> RTDECL(FPow8k)(
CppTypeFor<TypeCategory::Real, 8> b,
CppTypeFor<TypeCategory::Integer, 8> e);
#if LDBL_MANT_DIG == 64
#if HAS_FLOAT80
CppTypeFor<TypeCategory::Real, 10> RTDECL(FPow10k)(
CppTypeFor<TypeCategory::Real, 10> b,
CppTypeFor<TypeCategory::Integer, 8> e);
Expand Down
Loading
Loading