-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][runtime] Added Fortran::common::reference_wrapper for use on device. #85178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][runtime] Added Fortran::common::reference_wrapper for use on device. #85178
Conversation
Created using spr 1.3.4 [skip ci]
Created using spr 1.3.4
@llvm/pr-subscribers-flang-runtime Author: Slava Zakharin (vzakhari) ChangesThis is a simplified implementation of std::reference_wrapper that can be used Full diff: https://github.com/llvm/llvm-project/pull/85178.diff 2 Files Affected:
diff --git a/flang/include/flang/Common/reference-wrapper.h b/flang/include/flang/Common/reference-wrapper.h
new file mode 100644
index 00000000000000..66f924662d9612
--- /dev/null
+++ b/flang/include/flang/Common/reference-wrapper.h
@@ -0,0 +1,114 @@
+//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// clang-format off
+//
+// Implementation of std::reference_wrapper borrowed from libcu++
+// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1
+// with modifications.
+//
+// The original source code is distributed under the Apache License v2.0
+// with LLVM Exceptions.
+//
+// TODO: using libcu++ is the best option for CUDA, but there is a couple
+// of issues:
+// * The include paths need to be set up such that all STD header files
+// are taken from libcu++.
+// * cuda:: namespace need to be forced for all std:: references.
+//
+// clang-format on
+
+#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H
+#define FORTRAN_COMMON_REFERENCE_WRAPPER_H
+
+#include "flang/Runtime/api-attrs.h"
+#include <functional>
+#include <type_traits>
+
+#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \
+ (defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__)
+#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1
+#endif
+
+namespace Fortran::common {
+
+template <class _Tp>
+using __remove_cvref_t = std::remove_cv_t<std::remove_reference_t<_Tp>>;
+template <class _Tp, class _Up>
+struct __is_same_uncvref
+ : std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {};
+
+#if STD_REFERENCE_WRAPPER_UNSUPPORTED
+template <class _Tp> class reference_wrapper {
+public:
+ // types
+ typedef _Tp type;
+
+private:
+ type *__f_;
+
+ static RT_API_ATTRS void __fun(_Tp &);
+ static void __fun(_Tp &&) = delete;
+
+public:
+ template <class _Up,
+ class =
+ std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value,
+ decltype(__fun(std::declval<_Up>()))>>
+ constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) {
+ type &__f = static_cast<_Up &&>(__u);
+ __f_ = std::addressof(__f);
+ }
+
+ // access
+ constexpr RT_API_ATTRS operator type &() const { return *__f_; }
+ constexpr RT_API_ATTRS type &get() const { return *__f_; }
+
+ // invoke
+ template <class... _ArgTypes>
+ constexpr RT_API_ATTRS typename std::invoke_result_t<type &, _ArgTypes...>
+ operator()(_ArgTypes &&...__args) const {
+ return std::invoke(get(), std::forward<_ArgTypes>(__args)...);
+ }
+};
+
+template <class _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>;
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) {
+ return reference_wrapper<_Tp>(__t);
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(
+ reference_wrapper<_Tp> __t) {
+ return __t;
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
+ const _Tp &__t) {
+ return reference_wrapper<const _Tp>(__t);
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
+ reference_wrapper<_Tp> __t) {
+ return __t;
+}
+
+template <class _Tp> void ref(const _Tp &&) = delete;
+template <class _Tp> void cref(const _Tp &&) = delete;
+#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED
+using std::cref;
+using std::ref;
+using std::reference_wrapper;
+#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED
+
+} // namespace Fortran::common
+
+#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H
diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index 0477c32b3b53ad..e00d54980aae59 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -17,6 +17,7 @@
#include "internal-unit.h"
#include "io-error.h"
#include "flang/Common/optional.h"
+#include "flang/Common/reference-wrapper.h"
#include "flang/Common/visit.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/io-api.h"
@@ -210,39 +211,47 @@ class IoStatementState {
}
private:
- std::variant<std::reference_wrapper<OpenStatementState>,
- std::reference_wrapper<CloseStatementState>,
- std::reference_wrapper<NoopStatementState>,
- std::reference_wrapper<
+ std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
+ Fortran::common::reference_wrapper<CloseStatementState>,
+ Fortran::common::reference_wrapper<NoopStatementState>,
+ Fortran::common::reference_wrapper<
InternalFormattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
InternalFormattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<InternalListIoStatementState<Direction::Output>>,
- std::reference_wrapper<InternalListIoStatementState<Direction::Input>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
+ InternalListIoStatementState<Direction::Output>>,
+ Fortran::common::reference_wrapper<
+ InternalListIoStatementState<Direction::Input>>,
+ Fortran::common::reference_wrapper<
ExternalFormattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
ExternalFormattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<ExternalListIoStatementState<Direction::Output>>,
- std::reference_wrapper<ExternalListIoStatementState<Direction::Input>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
+ ExternalListIoStatementState<Direction::Output>>,
+ Fortran::common::reference_wrapper<
+ ExternalListIoStatementState<Direction::Input>>,
+ Fortran::common::reference_wrapper<
ExternalUnformattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
ExternalUnformattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<ChildFormattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<ChildFormattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<ChildListIoStatementState<Direction::Output>>,
- std::reference_wrapper<ChildListIoStatementState<Direction::Input>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
+ ChildFormattedIoStatementState<Direction::Output>>,
+ Fortran::common::reference_wrapper<
+ ChildFormattedIoStatementState<Direction::Input>>,
+ Fortran::common::reference_wrapper<
+ ChildListIoStatementState<Direction::Output>>,
+ Fortran::common::reference_wrapper<
+ ChildListIoStatementState<Direction::Input>>,
+ Fortran::common::reference_wrapper<
ChildUnformattedIoStatementState<Direction::Output>>,
- std::reference_wrapper<
+ Fortran::common::reference_wrapper<
ChildUnformattedIoStatementState<Direction::Input>>,
- std::reference_wrapper<InquireUnitState>,
- std::reference_wrapper<InquireNoUnitState>,
- std::reference_wrapper<InquireUnconnectedFileState>,
- std::reference_wrapper<InquireIOLengthState>,
- std::reference_wrapper<ExternalMiscIoStatementState>,
- std::reference_wrapper<ErroneousIoStatementState>>
+ Fortran::common::reference_wrapper<InquireUnitState>,
+ Fortran::common::reference_wrapper<InquireNoUnitState>,
+ Fortran::common::reference_wrapper<InquireUnconnectedFileState>,
+ Fortran::common::reference_wrapper<InquireIOLengthState>,
+ Fortran::common::reference_wrapper<ExternalMiscIoStatementState>,
+ Fortran::common::reference_wrapper<ErroneousIoStatementState>>
u_;
};
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Created using spr 1.3.4 [skip ci]
Created using spr 1.3.4 [skip ci]
This is a simplified implementation of std::reference_wrapper that can be used
in the offload builds for the device code. The methods are properly
marked with RT_API_ATTRS so that the device compilation succedes.