Skip to content

[flang][cuda] Improve data transfer detection by filtering symbols #98378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,16 @@ extern template semantics::UnorderedSymbolSet CollectSymbols(
extern template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SubscriptInteger> &);

// Collects Symbols of interest for the CUDA data transfer in an expression
template <typename A>
semantics::UnorderedSymbolSet CollectCudaSymbols(const A &);
extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
const Expr<SomeType> &);
extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
const Expr<SomeInteger> &);
extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
const Expr<SubscriptInteger> &);

// Predicate: does a variable contain a vector-valued subscript (not a triplet)?
bool HasVectorSubscript(const Expr<SomeType> &);

Expand Down Expand Up @@ -1236,7 +1246,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
// Get the number of distinct symbols with CUDA attribute in the expression.
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
semantics::UnorderedSymbolSet symbols;
for (const Symbol &sym : CollectSymbols(expr)) {
for (const Symbol &sym : CollectCudaSymbols(expr)) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
Expand All @@ -1259,7 +1269,7 @@ template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
unsigned hostSymbols{0};
unsigned deviceSymbols{0};
for (const Symbol &sym : CollectSymbols(expr)) {
for (const Symbol &sym : CollectCudaSymbols(expr)) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
Expand Down
29 changes: 29 additions & 0 deletions flang/lib/Evaluate/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,35 @@ template semantics::UnorderedSymbolSet CollectSymbols(
template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SubscriptInteger> &);

struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
semantics::UnorderedSymbolSet> {
using Base =
SetTraverse<CollectCudaSymbolsHelper, semantics::UnorderedSymbolSet>;
CollectCudaSymbolsHelper() : Base{*this} {}
using Base::operator();
semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
return {symbol};
}
// Overload some of the operator() to filter out the symbols that are not
// of interest for CUDA data transfer logic.
semantics::UnorderedSymbolSet operator()(const Subscript &) const {
return {};
}
semantics::UnorderedSymbolSet operator()(const ProcedureRef &) const {
return {};
}
};
template <typename A>
semantics::UnorderedSymbolSet CollectCudaSymbols(const A &x) {
return CollectCudaSymbolsHelper{}(x);
}
template semantics::UnorderedSymbolSet CollectCudaSymbols(
const Expr<SomeType> &);
template semantics::UnorderedSymbolSet CollectCudaSymbols(
const Expr<SomeInteger> &);
template semantics::UnorderedSymbolSet CollectCudaSymbols(
const Expr<SubscriptInteger> &);

// HasVectorSubscript()
struct HasVectorSubscriptHelper
: public AnyTraverse<HasVectorSubscriptHelper, bool,
Expand Down
34 changes: 32 additions & 2 deletions flang/test/Lower/CUDA/cuda-data-transfer.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ module mod1
type :: t1
integer :: i
end type

contains
function dev1(a)
integer, device :: a(:)
integer :: dev1
end function
end

subroutine sub1()
Expand Down Expand Up @@ -213,11 +219,35 @@ subroutine sub10(a, b)
res = a + b
end subroutine



! CHECK-LABEL: func.func @_QPsub10(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}

! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub10Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: cuf.data_transfer %[[A]]#1 to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
! CHECK-NOT: cuf.data_transfer

subroutine sub11(n)
integer :: n
real, dimension(10) :: h
real, dimension(n), device :: d
do i=1,10
h(i) = d(i)
end do
end subroutine

! CHECK-LABEL: func.func @_QPsub11
! CHECK: %[[RHS:.*]] = hlfir.designate %{{.*}} (%{{.*}}) : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
! CHECK: %[[LHS:.*]] = hlfir.designate %{{.*}} (%{{.*}}) : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
! CHECK: cuf.data_transfer %[[RHS]] to %[[LHS]] {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<f32>, !fir.ref<f32>

subroutine sub12()
use mod1
integer, device :: a(10)
integer :: x
x = dev1(a)
end subroutine

! CHECK-LABEL: func.func @_QPsub12
! CHECK: %{{.*}} = fir.call @_QMmod1Pdev1
! CHECK: hlfir.assign
! CHECK-NOT: cuf.data_transfer
Loading