Skip to content

[flang][cuda] Improve data transfer detection by filtering symbols #98378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 11, 2024

Conversation

clementval
Copy link
Contributor

@clementval clementval commented Jul 10, 2024

The current data transfer detection was collecting too many symbol and made wrong decision. This patch introduces a new function CollectCudaSymbols that is different than CollectSymbols and collect only symbol of interest for cuda data transfer in an expression.

Currently two cases where symbols are filtered out are:

  • array subscripts: only the array symbol is on interest, the indexing can be filtered out
  • function arguments: symbols of the function arguments are filtered out.

This fix some false positive data transfer and implicit data transfer.

More filtering might be needed and will be added as follow up patches.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:semantics labels Jul 10, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 10, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-semantics

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

The current data transfer detection was collecting too many symbol and made wrong decision. This patch introduces a new function CollectCudaSymbols that is different than CollectSymbols and collect only symbol of interest for cuda data transfer in an expression.

Currently two cases where symbols are filtered out are:

  • array subscripts: only the array symbol is on interest, the indexing can be filtered out
  • function arguments: symbols of the function arguments are filtered out.

Full diff: https://github.com/llvm/llvm-project/pull/98378.diff

3 Files Affected:

  • (modified) flang/include/flang/Evaluate/tools.h (+12-2)
  • (modified) flang/lib/Evaluate/tools.cpp (+29)
  • (modified) flang/test/Lower/CUDA/cuda-data-transfer.cuf (+31-2)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 625f9e5f6576f..8555073a2d0d4 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1073,6 +1073,16 @@ extern template semantics::UnorderedSymbolSet CollectSymbols(
 extern template semantics::UnorderedSymbolSet CollectSymbols(
     const Expr<SubscriptInteger> &);
 
+// Collects Symbols of interest for the CUDA data transfer in an expression
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeType> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeInteger> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SubscriptInteger> &);
+
 // Predicate: does a variable contain a vector-valued subscript (not a triplet)?
 bool HasVectorSubscript(const Expr<SomeType> &);
 
@@ -1236,7 +1246,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
 // Get the number of distinct symbols with CUDA attribute in the expression.
 template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
   semantics::UnorderedSymbolSet symbols;
-  for (const Symbol &sym : CollectSymbols(expr)) {
+  for (const Symbol &sym : CollectCudaSymbols(expr)) {
     if (const auto *details =
             sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
       if (details->cudaDataAttr() &&
@@ -1259,7 +1269,7 @@ template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
 inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
   unsigned hostSymbols{0};
   unsigned deviceSymbols{0};
-  for (const Symbol &sym : CollectSymbols(expr)) {
+  for (const Symbol &sym : CollectCudaSymbols(expr)) {
     if (const auto *details =
             sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
       if (details->cudaDataAttr() &&
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index a5f4faa0cef8f..34faba39ffd46 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1000,6 +1000,35 @@ template semantics::UnorderedSymbolSet CollectSymbols(
 template semantics::UnorderedSymbolSet CollectSymbols(
     const Expr<SubscriptInteger> &);
 
+struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
+                                      semantics::UnorderedSymbolSet> {
+  using Base =
+      SetTraverse<CollectCudaSymbolsHelper, semantics::UnorderedSymbolSet>;
+  CollectCudaSymbolsHelper() : Base{*this} {}
+  using Base::operator();
+  semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
+    return {symbol};
+  }
+  // Overload some of the operator() to filter out the symbols that are not
+  // of interest for CUDA data transfer logic.
+  semantics::UnorderedSymbolSet operator()(const Subscript &) const {
+    return {};
+  }
+  semantics::UnorderedSymbolSet operator()(const ProcedureRef &) const {
+    return {};
+  }
+};
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &x) {
+  return CollectCudaSymbolsHelper{}(x);
+}
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeType> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeInteger> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SubscriptInteger> &);
+
 // HasVectorSubscript()
 struct HasVectorSubscriptHelper
     : public AnyTraverse<HasVectorSubscriptHelper, bool,
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 0191de748d3eb..4929e1dcfabfc 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -6,6 +6,12 @@ module mod1
   type :: t1
     integer :: i
   end type
+
+contains
+  function dev1(a)
+    integer, device :: a(:)
+    integer :: dev1
+  end function
 end
 
 subroutine sub1()
@@ -213,11 +219,34 @@ subroutine sub10(a, b)
   res = a + b
 end subroutine
 
-
-
 ! CHECK-LABEL: func.func @_QPsub10(
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}
 
 ! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub10Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK: cuf.data_transfer %[[A]]#1 to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
 ! CHECK-NOT: cuf.data_transfer
+
+subroutine sub11(n)
+  integer :: n
+  real, dimension(10) :: h
+  real, dimension(n), device :: d
+  do i=1,10
+    h(i) = d(i)
+  end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub11
+! CHECK: %[[RHS:.*]] = hlfir.designate %{{.*}} (%{{.*}})  : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+! CHECK: %[[LHS:.*]] = hlfir.designate %{{.*}} (%{{.*}})  : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
+! CHECK: cuf.data_transfer %[[RHS]] to %[[LHS]] {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<f32>, !fir.ref<f32>
+
+subroutine sub12()
+  use mod1
+  integer, device :: a(10)
+  integer :: x
+  x = dev1(a)
+end subroutine
+
+! CHECK: %{{.*}} = fir.call @_QMmod1Pdev1
+! CHECK: hlfir.assign
+! CHECK-NOT: cuf.data_transfer

@clementval clementval merged commit e66ea43 into llvm:main Jul 11, 2024
7 checks passed
@clementval clementval deleted the temp_data_transfer branch July 11, 2024 16:36
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
…lvm#98378)

The current data transfer detection was collecting too many symbol and
made wrong decision. This patch introduces a new function
`CollectCudaSymbols` that is different than `CollectSymbols` and collect
only symbol of interest for cuda data transfer in an expression.

Currently two cases where symbols are filtered out are: 
- array subscripts: only the array symbol is on interest, the indexing
can be filtered out
- function arguments: symbols of the function arguments are filtered
out.

This fix some false positive data transfer and implicit data transfer. 

More filtering might be needed and will be added as follow up patches.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants