Skip to content

Commit 070984e

Browse files
vzakharitmsri
authored andcommitted
[flang][runtime] Use cuda::std::complex in F18 runtime CUDA build. (llvm#109078)
`std::complex` operators do not work for the CUDA device compilation of F18 runtime. This change makes use of `cuda::std::complex` from `libcudacxx`. `cuda::std::complex` does not have specializations for `long double`, so the change is accompanied with a clean-up for `long double` usage.
1 parent 4869afa commit 070984e

23 files changed

+480
-381
lines changed

flang/include/flang/Common/float80.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*===-- flang/Common/float80.h --------------------------------------*- C -*-===
2+
*
3+
* Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
* See https://llvm.org/LICENSE.txt for license information.
5+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
*
7+
*===----------------------------------------------------------------------===*/
8+
9+
/* This header is usable in both C and C++ code.
10+
* Isolates build compiler checks to determine if the 80-bit
11+
* floating point format is supported via a particular C type.
12+
* It defines CFloat80Type and CppFloat80Type aliases for this
13+
* C type.
14+
*/
15+
16+
#ifndef FORTRAN_COMMON_FLOAT80_H_
17+
#define FORTRAN_COMMON_FLOAT80_H_
18+
19+
#include "api-attrs.h"
20+
#include <float.h>
21+
22+
#if LDBL_MANT_DIG == 64
23+
#undef HAS_FLOAT80
24+
#define HAS_FLOAT80 1
25+
#endif
26+
27+
#if defined(RT_DEVICE_COMPILATION) && defined(__CUDACC__)
28+
/*
29+
* 'long double' is treated as 'double' in the CUDA device code,
30+
* and there is no support for 80-bit floating point format.
31+
* This is probably true for most offload devices, so RT_DEVICE_COMPILATION
32+
* check should be enough. For the time being, guard it with __CUDACC__
33+
* as well.
34+
*/
35+
#undef HAS_FLOAT80
36+
#endif
37+
38+
#if HAS_FLOAT80
39+
typedef long double CFloat80Type;
40+
typedef long double CppFloat80Type;
41+
#endif
42+
43+
#endif /* FORTRAN_COMMON_FLOAT80_H_ */

flang/include/flang/Runtime/complex.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===-- include/flang/Runtime/complex.h -------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// A single way to expose C++ complex class in files that can be used
10+
// in F18 runtime build. With inclusion of this file std::complex
11+
// and the related names become available, though, they may correspond
12+
// to alternative definitions (e.g. from cuda::std namespace).
13+
14+
#ifndef FORTRAN_RUNTIME_COMPLEX_H
15+
#define FORTRAN_RUNTIME_COMPLEX_H
16+
17+
#if RT_USE_LIBCUDACXX
18+
#include <cuda/std/complex>
19+
namespace Fortran::runtime::rtcmplx {
20+
using cuda::std::complex;
21+
using cuda::std::conj;
22+
} // namespace Fortran::runtime::rtcmplx
23+
#else // !RT_USE_LIBCUDACXX
24+
#include <complex>
25+
namespace Fortran::runtime::rtcmplx {
26+
using std::complex;
27+
using std::conj;
28+
} // namespace Fortran::runtime::rtcmplx
29+
#endif // !RT_USE_LIBCUDACXX
30+
31+
#endif // FORTRAN_RUNTIME_COMPLEX_H

flang/include/flang/Runtime/cpp-type.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313

1414
#include "flang/Common/Fortran.h"
1515
#include "flang/Common/float128.h"
16+
#include "flang/Common/float80.h"
1617
#include "flang/Common/uint128.h"
17-
#include <complex>
18+
#include "flang/Runtime/complex.h"
1819
#include <cstdint>
1920
#if __cplusplus >= 202302
2021
#include <stdfloat>
@@ -70,9 +71,9 @@ template <> struct CppTypeForHelper<TypeCategory::Real, 8> {
7071
using type = double;
7172
#endif
7273
};
73-
#if LDBL_MANT_DIG == 64
74+
#if HAS_FLOAT80
7475
template <> struct CppTypeForHelper<TypeCategory::Real, 10> {
75-
using type = long double;
76+
using type = CppFloat80Type;
7677
};
7778
#endif
7879
#if __STDCPP_FLOAT128_T__
@@ -89,7 +90,7 @@ template <> struct CppTypeForHelper<TypeCategory::Real, 16> {
8990
#endif
9091

9192
template <int KIND> struct CppTypeForHelper<TypeCategory::Complex, KIND> {
92-
using type = std::complex<CppTypeFor<TypeCategory::Real, KIND>>;
93+
using type = rtcmplx::complex<CppTypeFor<TypeCategory::Real, KIND>>;
9394
};
9495

9596
template <> struct CppTypeForHelper<TypeCategory::Character, 1> {

flang/include/flang/Runtime/matmul-instances.inc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
111111
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
112112
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
113113

114-
#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
114+
#if MATMUL_FORCE_ALL_TYPES || HAS_FLOAT80
115115
MATMUL_INSTANCE(Integer, 16, Real, 10)
116116
MATMUL_INSTANCE(Integer, 16, Complex, 10)
117117
MATMUL_INSTANCE(Real, 10, Integer, 16)
@@ -133,7 +133,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
133133
#endif
134134
#endif // MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
135135

136-
#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
136+
#if MATMUL_FORCE_ALL_TYPES || HAS_FLOAT80
137137
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro) \
138138
macro(Integer, 1, Real, 10) \
139139
macro(Integer, 1, Complex, 10) \
@@ -193,7 +193,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
193193
MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
194194
MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
195195
#endif
196-
#endif // MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
196+
#endif // MATMUL_FORCE_ALL_TYPES || HAS_FLOAT80
197197

198198
#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
199199
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro) \

flang/include/flang/Runtime/numeric.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ CppTypeFor<TypeCategory::Integer, 8> RTDECL(Ceiling8_8)(
4444
CppTypeFor<TypeCategory::Integer, 16> RTDECL(Ceiling8_16)(
4545
CppTypeFor<TypeCategory::Real, 8>);
4646
#endif
47-
#if LDBL_MANT_DIG == 64
47+
#if HAS_FLOAT80
4848
CppTypeFor<TypeCategory::Integer, 1> RTDECL(Ceiling10_1)(
4949
CppTypeFor<TypeCategory::Real, 10>);
5050
CppTypeFor<TypeCategory::Integer, 2> RTDECL(Ceiling10_2)(
@@ -78,7 +78,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(ErfcScaled4)(
7878
CppTypeFor<TypeCategory::Real, 4>);
7979
CppTypeFor<TypeCategory::Real, 8> RTDECL(ErfcScaled8)(
8080
CppTypeFor<TypeCategory::Real, 8>);
81-
#if LDBL_MANT_DIG == 64
81+
#if HAS_FLOAT80
8282
CppTypeFor<TypeCategory::Real, 10> RTDECL(ErfcScaled10)(
8383
CppTypeFor<TypeCategory::Real, 10>);
8484
#endif
@@ -96,7 +96,7 @@ CppTypeFor<TypeCategory::Integer, 4> RTDECL(Exponent8_4)(
9696
CppTypeFor<TypeCategory::Real, 8>);
9797
CppTypeFor<TypeCategory::Integer, 8> RTDECL(Exponent8_8)(
9898
CppTypeFor<TypeCategory::Real, 8>);
99-
#if LDBL_MANT_DIG == 64
99+
#if HAS_FLOAT80
100100
CppTypeFor<TypeCategory::Integer, 4> RTDECL(Exponent10_4)(
101101
CppTypeFor<TypeCategory::Real, 10>);
102102
CppTypeFor<TypeCategory::Integer, 8> RTDECL(Exponent10_8)(
@@ -134,7 +134,7 @@ CppTypeFor<TypeCategory::Integer, 8> RTDECL(Floor8_8)(
134134
CppTypeFor<TypeCategory::Integer, 16> RTDECL(Floor8_16)(
135135
CppTypeFor<TypeCategory::Real, 8>);
136136
#endif
137-
#if LDBL_MANT_DIG == 64
137+
#if HAS_FLOAT80
138138
CppTypeFor<TypeCategory::Integer, 1> RTDECL(Floor10_1)(
139139
CppTypeFor<TypeCategory::Real, 10>);
140140
CppTypeFor<TypeCategory::Integer, 2> RTDECL(Floor10_2)(
@@ -168,7 +168,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(Fraction4)(
168168
CppTypeFor<TypeCategory::Real, 4>);
169169
CppTypeFor<TypeCategory::Real, 8> RTDECL(Fraction8)(
170170
CppTypeFor<TypeCategory::Real, 8>);
171-
#if LDBL_MANT_DIG == 64
171+
#if HAS_FLOAT80
172172
CppTypeFor<TypeCategory::Real, 10> RTDECL(Fraction10)(
173173
CppTypeFor<TypeCategory::Real, 10>);
174174
#endif
@@ -180,7 +180,7 @@ CppTypeFor<TypeCategory::Real, 16> RTDECL(Fraction16)(
180180
// ISNAN / IEEE_IS_NAN
181181
bool RTDECL(IsNaN4)(CppTypeFor<TypeCategory::Real, 4>);
182182
bool RTDECL(IsNaN8)(CppTypeFor<TypeCategory::Real, 8>);
183-
#if LDBL_MANT_DIG == 64
183+
#if HAS_FLOAT80
184184
bool RTDECL(IsNaN10)(CppTypeFor<TypeCategory::Real, 10>);
185185
#endif
186186
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
@@ -212,7 +212,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(ModReal4)(
212212
CppTypeFor<TypeCategory::Real, 8> RTDECL(ModReal8)(
213213
CppTypeFor<TypeCategory::Real, 8>, CppTypeFor<TypeCategory::Real, 8>,
214214
const char *sourceFile = nullptr, int sourceLine = 0);
215-
#if LDBL_MANT_DIG == 64
215+
#if HAS_FLOAT80
216216
CppTypeFor<TypeCategory::Real, 10> RTDECL(ModReal10)(
217217
CppTypeFor<TypeCategory::Real, 10>, CppTypeFor<TypeCategory::Real, 10>,
218218
const char *sourceFile = nullptr, int sourceLine = 0);
@@ -247,7 +247,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(ModuloReal4)(
247247
CppTypeFor<TypeCategory::Real, 8> RTDECL(ModuloReal8)(
248248
CppTypeFor<TypeCategory::Real, 8>, CppTypeFor<TypeCategory::Real, 8>,
249249
const char *sourceFile = nullptr, int sourceLine = 0);
250-
#if LDBL_MANT_DIG == 64
250+
#if HAS_FLOAT80
251251
CppTypeFor<TypeCategory::Real, 10> RTDECL(ModuloReal10)(
252252
CppTypeFor<TypeCategory::Real, 10>, CppTypeFor<TypeCategory::Real, 10>,
253253
const char *sourceFile = nullptr, int sourceLine = 0);
@@ -283,7 +283,7 @@ CppTypeFor<TypeCategory::Integer, 8> RTDECL(Nint8_8)(
283283
CppTypeFor<TypeCategory::Integer, 16> RTDECL(Nint8_16)(
284284
CppTypeFor<TypeCategory::Real, 8>);
285285
#endif
286-
#if LDBL_MANT_DIG == 64
286+
#if HAS_FLOAT80
287287
CppTypeFor<TypeCategory::Integer, 1> RTDECL(Nint10_1)(
288288
CppTypeFor<TypeCategory::Real, 10>);
289289
CppTypeFor<TypeCategory::Integer, 2> RTDECL(Nint10_2)(
@@ -319,7 +319,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(Nearest4)(
319319
CppTypeFor<TypeCategory::Real, 4>, bool positive);
320320
CppTypeFor<TypeCategory::Real, 8> RTDECL(Nearest8)(
321321
CppTypeFor<TypeCategory::Real, 8>, bool positive);
322-
#if LDBL_MANT_DIG == 64
322+
#if HAS_FLOAT80
323323
CppTypeFor<TypeCategory::Real, 10> RTDECL(Nearest10)(
324324
CppTypeFor<TypeCategory::Real, 10>, bool positive);
325325
#endif
@@ -333,7 +333,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(RRSpacing4)(
333333
CppTypeFor<TypeCategory::Real, 4>);
334334
CppTypeFor<TypeCategory::Real, 8> RTDECL(RRSpacing8)(
335335
CppTypeFor<TypeCategory::Real, 8>);
336-
#if LDBL_MANT_DIG == 64
336+
#if HAS_FLOAT80
337337
CppTypeFor<TypeCategory::Real, 10> RTDECL(RRSpacing10)(
338338
CppTypeFor<TypeCategory::Real, 10>);
339339
#endif
@@ -347,7 +347,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(SetExponent4)(
347347
CppTypeFor<TypeCategory::Real, 4>, std::int64_t);
348348
CppTypeFor<TypeCategory::Real, 8> RTDECL(SetExponent8)(
349349
CppTypeFor<TypeCategory::Real, 8>, std::int64_t);
350-
#if LDBL_MANT_DIG == 64
350+
#if HAS_FLOAT80
351351
CppTypeFor<TypeCategory::Real, 10> RTDECL(SetExponent10)(
352352
CppTypeFor<TypeCategory::Real, 10>, std::int64_t);
353353
#endif
@@ -361,7 +361,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(Scale4)(
361361
CppTypeFor<TypeCategory::Real, 4>, std::int64_t);
362362
CppTypeFor<TypeCategory::Real, 8> RTDECL(Scale8)(
363363
CppTypeFor<TypeCategory::Real, 8>, std::int64_t);
364-
#if LDBL_MANT_DIG == 64
364+
#if HAS_FLOAT80
365365
CppTypeFor<TypeCategory::Real, 10> RTDECL(Scale10)(
366366
CppTypeFor<TypeCategory::Real, 10>, std::int64_t);
367367
#endif
@@ -410,7 +410,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(Spacing4)(
410410
CppTypeFor<TypeCategory::Real, 4>);
411411
CppTypeFor<TypeCategory::Real, 8> RTDECL(Spacing8)(
412412
CppTypeFor<TypeCategory::Real, 8>);
413-
#if LDBL_MANT_DIG == 64
413+
#if HAS_FLOAT80
414414
CppTypeFor<TypeCategory::Real, 10> RTDECL(Spacing10)(
415415
CppTypeFor<TypeCategory::Real, 10>);
416416
#endif
@@ -425,7 +425,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(FPow4i)(
425425
CppTypeFor<TypeCategory::Real, 8> RTDECL(FPow8i)(
426426
CppTypeFor<TypeCategory::Real, 8> b,
427427
CppTypeFor<TypeCategory::Integer, 4> e);
428-
#if LDBL_MANT_DIG == 64
428+
#if HAS_FLOAT80
429429
CppTypeFor<TypeCategory::Real, 10> RTDECL(FPow10i)(
430430
CppTypeFor<TypeCategory::Real, 10> b,
431431
CppTypeFor<TypeCategory::Integer, 4> e);
@@ -442,7 +442,7 @@ CppTypeFor<TypeCategory::Real, 4> RTDECL(FPow4k)(
442442
CppTypeFor<TypeCategory::Real, 8> RTDECL(FPow8k)(
443443
CppTypeFor<TypeCategory::Real, 8> b,
444444
CppTypeFor<TypeCategory::Integer, 8> e);
445-
#if LDBL_MANT_DIG == 64
445+
#if HAS_FLOAT80
446446
CppTypeFor<TypeCategory::Real, 10> RTDECL(FPow10k)(
447447
CppTypeFor<TypeCategory::Real, 10> b,
448448
CppTypeFor<TypeCategory::Integer, 8> e);

0 commit comments

Comments
 (0)