Skip to content

[flang][runtime] Added Fortran::common::reference_wrapper for use on device. #85178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

vzakhari
Copy link
Contributor

This is a simplified implementation of std::reference_wrapper that can be used
in the offload builds for the device code. The methods are properly
marked with RT_API_ATTRS so that the device compilation succedes.

Created using spr 1.3.4
@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category labels Mar 14, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 14, 2024

@llvm/pr-subscribers-flang-runtime

Author: Slava Zakharin (vzakhari)

Changes

This is a simplified implementation of std::reference_wrapper that can be used
in the offload builds for the device code. The methods are properly
marked with RT_API_ATTRS so that the device compilation succedes.


Full diff: https://github.com/llvm/llvm-project/pull/85178.diff

2 Files Affected:

  • (added) flang/include/flang/Common/reference-wrapper.h (+114)
  • (modified) flang/runtime/io-stmt.h (+34-25)
diff --git a/flang/include/flang/Common/reference-wrapper.h b/flang/include/flang/Common/reference-wrapper.h
new file mode 100644
index 00000000000000..66f924662d9612
--- /dev/null
+++ b/flang/include/flang/Common/reference-wrapper.h
@@ -0,0 +1,114 @@
+//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// clang-format off
+//
+// Implementation of std::reference_wrapper borrowed from libcu++
+// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1
+// with modifications.
+//
+// The original source code is distributed under the Apache License v2.0
+// with LLVM Exceptions.
+//
+// TODO: using libcu++ is the best option for CUDA, but there is a couple
+// of issues:
+//   * The include paths need to be set up such that all STD header files
+//     are taken from libcu++.
+//   * cuda:: namespace need to be forced for all std:: references.
+//
+// clang-format on
+
+#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H
+#define FORTRAN_COMMON_REFERENCE_WRAPPER_H
+
+#include "flang/Runtime/api-attrs.h"
+#include <functional>
+#include <type_traits>
+
+#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \
+    (defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__)
+#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1
+#endif
+
+namespace Fortran::common {
+
+template <class _Tp>
+using __remove_cvref_t = std::remove_cv_t<std::remove_reference_t<_Tp>>;
+template <class _Tp, class _Up>
+struct __is_same_uncvref
+    : std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {};
+
+#if STD_REFERENCE_WRAPPER_UNSUPPORTED
+template <class _Tp> class reference_wrapper {
+public:
+  // types
+  typedef _Tp type;
+
+private:
+  type *__f_;
+
+  static RT_API_ATTRS void __fun(_Tp &);
+  static void __fun(_Tp &&) = delete;
+
+public:
+  template <class _Up,
+      class =
+          std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value,
+              decltype(__fun(std::declval<_Up>()))>>
+  constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) {
+    type &__f = static_cast<_Up &&>(__u);
+    __f_ = std::addressof(__f);
+  }
+
+  // access
+  constexpr RT_API_ATTRS operator type &() const { return *__f_; }
+  constexpr RT_API_ATTRS type &get() const { return *__f_; }
+
+  // invoke
+  template <class... _ArgTypes>
+  constexpr RT_API_ATTRS typename std::invoke_result_t<type &, _ArgTypes...>
+  operator()(_ArgTypes &&...__args) const {
+    return std::invoke(get(), std::forward<_ArgTypes>(__args)...);
+  }
+};
+
+template <class _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>;
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) {
+  return reference_wrapper<_Tp>(__t);
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(
+    reference_wrapper<_Tp> __t) {
+  return __t;
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
+    const _Tp &__t) {
+  return reference_wrapper<const _Tp>(__t);
+}
+
+template <class _Tp>
+inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
+    reference_wrapper<_Tp> __t) {
+  return __t;
+}
+
+template <class _Tp> void ref(const _Tp &&) = delete;
+template <class _Tp> void cref(const _Tp &&) = delete;
+#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED
+using std::cref;
+using std::ref;
+using std::reference_wrapper;
+#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED
+
+} // namespace Fortran::common
+
+#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H
diff --git a/flang/runtime/io-stmt.h b/flang/runtime/io-stmt.h
index 0477c32b3b53ad..e00d54980aae59 100644
--- a/flang/runtime/io-stmt.h
+++ b/flang/runtime/io-stmt.h
@@ -17,6 +17,7 @@
 #include "internal-unit.h"
 #include "io-error.h"
 #include "flang/Common/optional.h"
+#include "flang/Common/reference-wrapper.h"
 #include "flang/Common/visit.h"
 #include "flang/Runtime/descriptor.h"
 #include "flang/Runtime/io-api.h"
@@ -210,39 +211,47 @@ class IoStatementState {
   }
 
 private:
-  std::variant<std::reference_wrapper<OpenStatementState>,
-      std::reference_wrapper<CloseStatementState>,
-      std::reference_wrapper<NoopStatementState>,
-      std::reference_wrapper<
+  std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
+      Fortran::common::reference_wrapper<CloseStatementState>,
+      Fortran::common::reference_wrapper<NoopStatementState>,
+      Fortran::common::reference_wrapper<
           InternalFormattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
           InternalFormattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<InternalListIoStatementState<Direction::Output>>,
-      std::reference_wrapper<InternalListIoStatementState<Direction::Input>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
+          InternalListIoStatementState<Direction::Output>>,
+      Fortran::common::reference_wrapper<
+          InternalListIoStatementState<Direction::Input>>,
+      Fortran::common::reference_wrapper<
           ExternalFormattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
           ExternalFormattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<ExternalListIoStatementState<Direction::Output>>,
-      std::reference_wrapper<ExternalListIoStatementState<Direction::Input>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
+          ExternalListIoStatementState<Direction::Output>>,
+      Fortran::common::reference_wrapper<
+          ExternalListIoStatementState<Direction::Input>>,
+      Fortran::common::reference_wrapper<
           ExternalUnformattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
           ExternalUnformattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<ChildFormattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<ChildFormattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<ChildListIoStatementState<Direction::Output>>,
-      std::reference_wrapper<ChildListIoStatementState<Direction::Input>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
+          ChildFormattedIoStatementState<Direction::Output>>,
+      Fortran::common::reference_wrapper<
+          ChildFormattedIoStatementState<Direction::Input>>,
+      Fortran::common::reference_wrapper<
+          ChildListIoStatementState<Direction::Output>>,
+      Fortran::common::reference_wrapper<
+          ChildListIoStatementState<Direction::Input>>,
+      Fortran::common::reference_wrapper<
           ChildUnformattedIoStatementState<Direction::Output>>,
-      std::reference_wrapper<
+      Fortran::common::reference_wrapper<
           ChildUnformattedIoStatementState<Direction::Input>>,
-      std::reference_wrapper<InquireUnitState>,
-      std::reference_wrapper<InquireNoUnitState>,
-      std::reference_wrapper<InquireUnconnectedFileState>,
-      std::reference_wrapper<InquireIOLengthState>,
-      std::reference_wrapper<ExternalMiscIoStatementState>,
-      std::reference_wrapper<ErroneousIoStatementState>>
+      Fortran::common::reference_wrapper<InquireUnitState>,
+      Fortran::common::reference_wrapper<InquireNoUnitState>,
+      Fortran::common::reference_wrapper<InquireUnconnectedFileState>,
+      Fortran::common::reference_wrapper<InquireIOLengthState>,
+      Fortran::common::reference_wrapper<ExternalMiscIoStatementState>,
+      Fortran::common::reference_wrapper<ErroneousIoStatementState>>
       u_;
 };
 

@vzakhari vzakhari requested review from klausler and jeanPerier March 14, 2024 05:33
Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

vzakhari and others added 4 commits March 15, 2024 14:27
Created using spr 1.3.4

[skip ci]
Created using spr 1.3.4
Created using spr 1.3.4

[skip ci]
Created using spr 1.3.4
@vzakhari vzakhari changed the base branch from users/vzakhari/spr/main.flangruntime-added-fortrancommonreference_wrapper-for-use-on-device to main March 15, 2024 21:41
@vzakhari vzakhari merged commit d8f97c0 into main Mar 15, 2024
@vzakhari vzakhari deleted the users/vzakhari/spr/flangruntime-added-fortrancommonreference_wrapper-for-use-on-device branch March 15, 2024 21:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants