Skip to content

[flang][runtime] Added Fortran::common::reference_wrapper for use on device. #85178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions flang/include/flang/Common/reference-wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//===-- include/flang/Common/reference-wrapper.h ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// clang-format off
//
// Implementation of std::reference_wrapper borrowed from libcu++
// https://github.com/NVIDIA/libcudacxx/blob/f7e6cd07ed5ba826aeac0b742feafddfedc1e400/include/cuda/std/detail/libcxx/include/__functional/reference_wrapper.h#L1
// with modifications.
//
// The original source code is distributed under the Apache License v2.0
// with LLVM Exceptions.
//
// TODO: using libcu++ is the best option for CUDA, but there is a couple
// of issues:
// * The include paths need to be set up such that all STD header files
// are taken from libcu++.
// * cuda:: namespace need to be forced for all std:: references.
//
// clang-format on

#ifndef FORTRAN_COMMON_REFERENCE_WRAPPER_H
#define FORTRAN_COMMON_REFERENCE_WRAPPER_H

#include "flang/Runtime/api-attrs.h"
#include <functional>
#include <type_traits>

#if !defined(STD_REFERENCE_WRAPPER_UNSUPPORTED) && \
(defined(__CUDACC__) || defined(__CUDA__)) && defined(__CUDA_ARCH__)
#define STD_REFERENCE_WRAPPER_UNSUPPORTED 1
#endif

namespace Fortran::common {

template <class _Tp>
using __remove_cvref_t = std::remove_cv_t<std::remove_reference_t<_Tp>>;
template <class _Tp, class _Up>
struct __is_same_uncvref
: std::is_same<__remove_cvref_t<_Tp>, __remove_cvref_t<_Up>> {};

#if STD_REFERENCE_WRAPPER_UNSUPPORTED
template <class _Tp> class reference_wrapper {
public:
// types
typedef _Tp type;

private:
type *__f_;

static RT_API_ATTRS void __fun(_Tp &);
static void __fun(_Tp &&) = delete;

public:
template <class _Up,
class =
std::enable_if_t<!__is_same_uncvref<_Up, reference_wrapper>::value,
decltype(__fun(std::declval<_Up>()))>>
constexpr RT_API_ATTRS reference_wrapper(_Up &&__u) {
type &__f = static_cast<_Up &&>(__u);
__f_ = std::addressof(__f);
}

// access
constexpr RT_API_ATTRS operator type &() const { return *__f_; }
constexpr RT_API_ATTRS type &get() const { return *__f_; }

// invoke
template <class... _ArgTypes>
constexpr RT_API_ATTRS typename std::invoke_result_t<type &, _ArgTypes...>
operator()(_ArgTypes &&...__args) const {
return std::invoke(get(), std::forward<_ArgTypes>(__args)...);
}
};

template <class _Tp> reference_wrapper(_Tp &) -> reference_wrapper<_Tp>;

template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(_Tp &__t) {
return reference_wrapper<_Tp>(__t);
}

template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<_Tp> ref(
reference_wrapper<_Tp> __t) {
return __t;
}

template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
const _Tp &__t) {
return reference_wrapper<const _Tp>(__t);
}

template <class _Tp>
inline constexpr RT_API_ATTRS reference_wrapper<const _Tp> cref(
reference_wrapper<_Tp> __t) {
return __t;
}

template <class _Tp> void ref(const _Tp &&) = delete;
template <class _Tp> void cref(const _Tp &&) = delete;
#else // !STD_REFERENCE_WRAPPER_UNSUPPORTED
using std::cref;
using std::ref;
using std::reference_wrapper;
#endif // !STD_REFERENCE_WRAPPER_UNSUPPORTED

} // namespace Fortran::common

#endif // FORTRAN_COMMON_REFERENCE_WRAPPER_H
59 changes: 34 additions & 25 deletions flang/runtime/io-stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "internal-unit.h"
#include "io-error.h"
#include "flang/Common/optional.h"
#include "flang/Common/reference-wrapper.h"
#include "flang/Common/visit.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/io-api.h"
Expand Down Expand Up @@ -210,39 +211,47 @@ class IoStatementState {
}

private:
std::variant<std::reference_wrapper<OpenStatementState>,
std::reference_wrapper<CloseStatementState>,
std::reference_wrapper<NoopStatementState>,
std::reference_wrapper<
std::variant<Fortran::common::reference_wrapper<OpenStatementState>,
Fortran::common::reference_wrapper<CloseStatementState>,
Fortran::common::reference_wrapper<NoopStatementState>,
Fortran::common::reference_wrapper<
InternalFormattedIoStatementState<Direction::Output>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
InternalFormattedIoStatementState<Direction::Input>>,
std::reference_wrapper<InternalListIoStatementState<Direction::Output>>,
std::reference_wrapper<InternalListIoStatementState<Direction::Input>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
InternalListIoStatementState<Direction::Output>>,
Fortran::common::reference_wrapper<
InternalListIoStatementState<Direction::Input>>,
Fortran::common::reference_wrapper<
ExternalFormattedIoStatementState<Direction::Output>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ExternalFormattedIoStatementState<Direction::Input>>,
std::reference_wrapper<ExternalListIoStatementState<Direction::Output>>,
std::reference_wrapper<ExternalListIoStatementState<Direction::Input>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ExternalListIoStatementState<Direction::Output>>,
Fortran::common::reference_wrapper<
ExternalListIoStatementState<Direction::Input>>,
Fortran::common::reference_wrapper<
ExternalUnformattedIoStatementState<Direction::Output>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ExternalUnformattedIoStatementState<Direction::Input>>,
std::reference_wrapper<ChildFormattedIoStatementState<Direction::Output>>,
std::reference_wrapper<ChildFormattedIoStatementState<Direction::Input>>,
std::reference_wrapper<ChildListIoStatementState<Direction::Output>>,
std::reference_wrapper<ChildListIoStatementState<Direction::Input>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ChildFormattedIoStatementState<Direction::Output>>,
Fortran::common::reference_wrapper<
ChildFormattedIoStatementState<Direction::Input>>,
Fortran::common::reference_wrapper<
ChildListIoStatementState<Direction::Output>>,
Fortran::common::reference_wrapper<
ChildListIoStatementState<Direction::Input>>,
Fortran::common::reference_wrapper<
ChildUnformattedIoStatementState<Direction::Output>>,
std::reference_wrapper<
Fortran::common::reference_wrapper<
ChildUnformattedIoStatementState<Direction::Input>>,
std::reference_wrapper<InquireUnitState>,
std::reference_wrapper<InquireNoUnitState>,
std::reference_wrapper<InquireUnconnectedFileState>,
std::reference_wrapper<InquireIOLengthState>,
std::reference_wrapper<ExternalMiscIoStatementState>,
std::reference_wrapper<ErroneousIoStatementState>>
Fortran::common::reference_wrapper<InquireUnitState>,
Fortran::common::reference_wrapper<InquireNoUnitState>,
Fortran::common::reference_wrapper<InquireUnconnectedFileState>,
Fortran::common::reference_wrapper<InquireIOLengthState>,
Fortran::common::reference_wrapper<ExternalMiscIoStatementState>,
Fortran::common::reference_wrapper<ErroneousIoStatementState>>
u_;
};

Expand Down