Skip to content

Commit e66ea43

Browse files
authored
[flang][cuda] Improve data transfer detection by filtering symbols (#98378)
The current data transfer detection was collecting too many symbols and made wrong decisions. This patch introduces a new function `CollectCudaSymbols` that differs from `CollectSymbols` and collects only the symbols of interest for CUDA data transfer in an expression. Currently, symbols are filtered out in two cases: - array subscripts: only the array symbol is of interest, so the symbols used for indexing are filtered out - function arguments: symbols of the function arguments are filtered out. This fixes some false-positive data transfers and implicit data transfers. More filtering might be needed and will be added in follow-up patches.
1 parent e16882f commit e66ea43

File tree

3 files changed

+73
-4
lines changed

3 files changed

+73
-4
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,16 @@ extern template semantics::UnorderedSymbolSet CollectSymbols(
10731073
extern template semantics::UnorderedSymbolSet CollectSymbols(
10741074
const Expr<SubscriptInteger> &);
10751075

1076+
// Collects Symbols of interest for the CUDA data transfer in an expression
1077+
template <typename A>
1078+
semantics::UnorderedSymbolSet CollectCudaSymbols(const A &);
1079+
extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
1080+
const Expr<SomeType> &);
1081+
extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
1082+
const Expr<SomeInteger> &);
1083+
extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
1084+
const Expr<SubscriptInteger> &);
1085+
10761086
// Predicate: does a variable contain a vector-valued subscript (not a triplet)?
10771087
bool HasVectorSubscript(const Expr<SomeType> &);
10781088

@@ -1236,7 +1246,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
12361246
// Get the number of distinct symbols with CUDA attribute in the expression.
12371247
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
12381248
semantics::UnorderedSymbolSet symbols;
1239-
for (const Symbol &sym : CollectSymbols(expr)) {
1249+
for (const Symbol &sym : CollectCudaSymbols(expr)) {
12401250
if (const auto *details =
12411251
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
12421252
if (details->cudaDataAttr() &&
@@ -1259,7 +1269,7 @@ template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
12591269
inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
12601270
unsigned hostSymbols{0};
12611271
unsigned deviceSymbols{0};
1262-
for (const Symbol &sym : CollectSymbols(expr)) {
1272+
for (const Symbol &sym : CollectCudaSymbols(expr)) {
12631273
if (const auto *details =
12641274
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
12651275
if (details->cudaDataAttr() &&

flang/lib/Evaluate/tools.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,35 @@ template semantics::UnorderedSymbolSet CollectSymbols(
10001000
template semantics::UnorderedSymbolSet CollectSymbols(
10011001
const Expr<SubscriptInteger> &);
10021002

1003+
// Expression traversal helper that gathers the symbols relevant to CUDA data
// transfer detection. It is built on SetTraverse with an UnorderedSymbolSet
// result; node types without an overload below are handled by the inherited
// Base::operator() (presumably visiting children and merging their sets --
// confirm against SetTraverse).
struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
1004+
semantics::UnorderedSymbolSet> {
1005+
using Base =
1006+
SetTraverse<CollectCudaSymbolsHelper, semantics::UnorderedSymbolSet>;
1007+
CollectCudaSymbolsHelper() : Base{*this} {}
1008+
using Base::operator();
1009+
// A symbol reached by the traversal contributes a one-element set.
semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
1010+
return {symbol};
1011+
}
1012+
// Overload some of the operator() to filter out the symbols that are not
1013+
// of interest for CUDA data transfer logic.
1014+
// Subscripts contribute nothing: the traversal stops here and returns an
// empty set, so symbols used only for array indexing are not collected.
semantics::UnorderedSymbolSet operator()(const Subscript &) const {
1015+
return {};
1016+
}
1017+
// Procedure references contribute nothing: an empty set is returned without
// descending, so symbols appearing as function arguments are not collected.
semantics::UnorderedSymbolSet operator()(const ProcedureRef &) const {
1018+
return {};
1019+
}
1020+
};
1021+
// Returns the set of symbols obtained by applying a freshly constructed
// CollectCudaSymbolsHelper to x.
template <typename A>
1022+
semantics::UnorderedSymbolSet CollectCudaSymbols(const A &x) {
1023+
return CollectCudaSymbolsHelper{}(x);
1024+
}
1025+
template semantics::UnorderedSymbolSet CollectCudaSymbols(
1026+
const Expr<SomeType> &);
1027+
template semantics::UnorderedSymbolSet CollectCudaSymbols(
1028+
const Expr<SomeInteger> &);
1029+
template semantics::UnorderedSymbolSet CollectCudaSymbols(
1030+
const Expr<SubscriptInteger> &);
1031+
10031032
// HasVectorSubscript()
10041033
struct HasVectorSubscriptHelper
10051034
: public AnyTraverse<HasVectorSubscriptHelper, bool,

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ module mod1
66
type :: t1
77
integer :: i
88
end type
9+
10+
contains
11+
function dev1(a)
12+
integer, device :: a(:)
13+
integer :: dev1
14+
end function
915
end
1016

1117
subroutine sub1()
@@ -213,11 +219,35 @@ subroutine sub10(a, b)
213219
res = a + b
214220
end subroutine
215221

216-
217-
218222
! CHECK-LABEL: func.func @_QPsub10(
219223
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}
220224

221225
! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub10Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
222226
! CHECK: cuf.data_transfer %[[A]]#1 to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
223227
! CHECK-NOT: cuf.data_transfer
228+
229+
subroutine sub11(n)
230+
integer :: n
231+
real, dimension(10) :: h
232+
real, dimension(n), device :: d
233+
do i=1,10
234+
h(i) = d(i)
235+
end do
236+
end subroutine
237+
238+
! CHECK-LABEL: func.func @_QPsub11
239+
! CHECK: %[[RHS:.*]] = hlfir.designate %{{.*}} (%{{.*}}) : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
240+
! CHECK: %[[LHS:.*]] = hlfir.designate %{{.*}} (%{{.*}}) : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
241+
! CHECK: cuf.data_transfer %[[RHS]] to %[[LHS]] {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<f32>, !fir.ref<f32>
242+
243+
subroutine sub12()
244+
use mod1
245+
integer, device :: a(10)
246+
integer :: x
247+
x = dev1(a)
248+
end subroutine
249+
250+
! CHECK-LABEL: func.func @_QPsub12
251+
! CHECK: %{{.*}} = fir.call @_QMmod1Pdev1
252+
! CHECK: hlfir.assign
253+
! CHECK-NOT: cuf.data_transfer

0 commit comments

Comments
 (0)