Skip to content

Commit d8f97c0

Browse files
authored
[flang][runtime] Added Fortran::common::reference_wrapper for use on device.
This is a simplified implementation of std::reference_wrapper that can be used in the offload builds for the device code. The methods are properly marked with RT_API_ATTRS so that the device compilation succedes. Reviewers: jeanPerier, klausler Reviewed By: jeanPerier Pull Request: #85178
1 parent 6e1959d commit d8f97c0

File tree

2 files changed

+148
-25
lines changed

2 files changed

+148
-25
lines changed
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// clang-format off
9+
//
10+
// Implementation of std::reference_wrapper borrowed from libcu++
11+
// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1
12+
// with modifications.
13+
//
14+
// The original source code is distributed under the Apache License v2.0
15+
// with LLVM Exceptions.
16+
//
17+
// TODO: using libcu++ is the best option for CUDA, but there is a couple
18+
// of issues:
19+
// * The include paths need to be set up such that all STD header files
20+
// are taken from libcu++.
21+
// * cuda:: namespace need to be forced for all std:: references.
22+
//
23+
// clang-format on
24+
25+
#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H
26+
#define FORTRAN_COMMON_REFERENCE_WRAPPER_H
27+
28+
#include "flang/Runtime/api-attrs.h"
29+
#include <functional>
30+
#include <type_traits>
31+
32+
#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \
33+
(defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__)
34+
#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1
35+
#endif
36+
37+
namespace Fortran::common {
38+
39+
template <class _Tp>
40+
using __remove_cvref_t = std::remove_cv_t<std::remove_reference_t<_Tp>>;
41+
template <class _Tp, class _Up>
42+
struct __is_same_uncvref
43+
: std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {};
44+
45+
#if STD_REFERENCE_WRAPPER_UNSUPPORTED
46+
template <class _Tp> class reference_wrapper {
47+
public:
48+
// types
49+
typedef _Tp type;
50+
51+
private:
52+
type *__f_;
53+
54+
static RT_API_ATTRS void __fun(_Tp &);
55+
static void __fun(_Tp &&) = delete;
56+
57+
public:
58+
template <class _Up,
59+
class =
60+
std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value,
61+
decltype(__fun(std::declval<_Up>()))>>
62+
constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) {
63+
type &__f = static_cast<_Up &&>(__u);
64+
__f_ = std::addressof(__f);
65+
}
66+
67+
// access
68+
constexpr RT_API_ATTRS operator type &() const { return *__f_; }
69+
constexpr RT_API_ATTRS type &get() const { return *__f_; }
70+
71+
// invoke
72+
template <class... _ArgTypes>
73+
constexpr RT_API_ATTRS typename std::invoke_result_t<type &, _ArgTypes...>
74+
operator()(_ArgTypes &&...__args) const {
75+
return std::invoke(get(), std::forward<_ArgTypes>(__args)...);
76+
}
77+
};
78+
79+
template <class _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>;
80+
81+
template <class _Tp>
82+
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) {
83+
return reference_wrapper<_Tp>(__t);
84+
}
85+
86+
template <class _Tp>
87+
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(
88+
reference_wrapper<_Tp> __t) {
89+
return __t;
90+
}
91+
92+
template <class _Tp>
93+
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
94+
const _Tp &__t) {
95+
return reference_wrapper<const _Tp>(__t);
96+
}
97+
98+
template <class _Tp>
99+
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
100+
reference_wrapper<_Tp> __t) {
101+
return __t;
102+
}
103+
104+
template <class _Tp> void ref(const _Tp &&) = delete;
105+
template <class _Tp> void cref(const _Tp &&) = delete;
106+
#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED
107+
using std::cref;
108+
using std::ref;
109+
using std::reference_wrapper;
110+
#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED
111+
112+
} // namespace Fortran::common
113+
114+
#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H

flang/runtime/io-stmt.h

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "internal-unit.h"
1818
#include "io-error.h"
1919
#include "flang/Common/optional.h"
20+
#include "flang/Common/reference-wrapper.h"
2021
#include "flang/Common/visit.h"
2122
#include "flang/Runtime/descriptor.h"
2223
#include "flang/Runtime/io-api.h"
@@ -210,39 +211,47 @@ class IoStatementState {
210211
}
211212

212213
private:
213-
std::variant<std::reference_wrapper<OpenStatementState>,
214-
std::reference_wrapper<CloseStatementState>,
215-
std::reference_wrapper<NoopStatementState>,
216-
std::reference_wrapper<
214+
std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
215+
Fortran::common::reference_wrapper<CloseStatementState>,
216+
Fortran::common::reference_wrapper<NoopStatementState>,
217+
Fortran::common::reference_wrapper<
217218
InternalFormattedIoStatementState<Direction::Output>>,
218-
std::reference_wrapper<
219+
Fortran::common::reference_wrapper<
219220
InternalFormattedIoStatementState<Direction::Input>>,
220-
std::reference_wrapper<InternalListIoStatementState<Direction::Output>>,
221-
std::reference_wrapper<InternalListIoStatementState<Direction::Input>>,
222-
std::reference_wrapper<
221+
Fortran::common::reference_wrapper<
222+
InternalListIoStatementState<Direction::Output>>,
223+
Fortran::common::reference_wrapper<
224+
InternalListIoStatementState<Direction::Input>>,
225+
Fortran::common::reference_wrapper<
223226
ExternalFormattedIoStatementState<Direction::Output>>,
224-
std::reference_wrapper<
227+
Fortran::common::reference_wrapper<
225228
ExternalFormattedIoStatementState<Direction::Input>>,
226-
std::reference_wrapper<ExternalListIoStatementState<Direction::Output>>,
227-
std::reference_wrapper<ExternalListIoStatementState<Direction::Input>>,
228-
std::reference_wrapper<
229+
Fortran::common::reference_wrapper<
230+
ExternalListIoStatementState<Direction::Output>>,
231+
Fortran::common::reference_wrapper<
232+
ExternalListIoStatementState<Direction::Input>>,
233+
Fortran::common::reference_wrapper<
229234
ExternalUnformattedIoStatementState<Direction::Output>>,
230-
std::reference_wrapper<
235+
Fortran::common::reference_wrapper<
231236
ExternalUnformattedIoStatementState<Direction::Input>>,
232-
std::reference_wrapper<ChildFormattedIoStatementState<Direction::Output>>,
233-
std::reference_wrapper<ChildFormattedIoStatementState<Direction::Input>>,
234-
std::reference_wrapper<ChildListIoStatementState<Direction::Output>>,
235-
std::reference_wrapper<ChildListIoStatementState<Direction::Input>>,
236-
std::reference_wrapper<
237+
Fortran::common::reference_wrapper<
238+
ChildFormattedIoStatementState<Direction::Output>>,
239+
Fortran::common::reference_wrapper<
240+
ChildFormattedIoStatementState<Direction::Input>>,
241+
Fortran::common::reference_wrapper<
242+
ChildListIoStatementState<Direction::Output>>,
243+
Fortran::common::reference_wrapper<
244+
ChildListIoStatementState<Direction::Input>>,
245+
Fortran::common::reference_wrapper<
237246
ChildUnformattedIoStatementState<Direction::Output>>,
238-
std::reference_wrapper<
247+
Fortran::common::reference_wrapper<
239248
ChildUnformattedIoStatementState<Direction::Input>>,
240-
std::reference_wrapper<InquireUnitState>,
241-
std::reference_wrapper<InquireNoUnitState>,
242-
std::reference_wrapper<InquireUnconnectedFileState>,
243-
std::reference_wrapper<InquireIOLengthState>,
244-
std::reference_wrapper<ExternalMiscIoStatementState>,
245-
std::reference_wrapper<ErroneousIoStatementState>>
249+
Fortran::common::reference_wrapper<InquireUnitState>,
250+
Fortran::common::reference_wrapper<InquireNoUnitState>,
251+
Fortran::common::reference_wrapper<InquireUnconnectedFileState>,
252+
Fortran::common::reference_wrapper<InquireIOLengthState>,
253+
Fortran::common::reference_wrapper<ExternalMiscIoStatementState>,
254+
Fortran::common::reference_wrapper<ErroneousIoStatementState>>
246255
u_;
247256
};
248257

0 commit comments

Comments
 (0)