-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][cuda] Improve data transfer detection by filtering symbols #98378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
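@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-semantics Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: The current data transfer detection was collecting too many symbols and made wrong decisions. This patch introduces a new function `CollectCudaSymbols` that is different from `CollectSymbols` and collects only the symbols of interest for CUDA data transfer in an expression. Currently the two cases where symbols are filtered out are: (1) array subscripts: only the array symbol is of interest, so the indexing can be filtered out; (2) function arguments: symbols of the function arguments are filtered out.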
Full diff: https://github.com/llvm/llvm-project/pull/98378.diff 3 Files Affected:
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 625f9e5f6576f..8555073a2d0d4 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1073,6 +1073,16 @@ extern template semantics::UnorderedSymbolSet CollectSymbols(
extern template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SubscriptInteger> &);
+// Collects Symbols of interest for the CUDA data transfer in an expression
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SomeType> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SomeInteger> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SubscriptInteger> &);
+
// Predicate: does a variable contain a vector-valued subscript (not a triplet)?
bool HasVectorSubscript(const Expr<SomeType> &);
@@ -1236,7 +1246,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
// Get the number of distinct symbols with CUDA attribute in the expression.
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
semantics::UnorderedSymbolSet symbols;
- for (const Symbol &sym : CollectSymbols(expr)) {
+ for (const Symbol &sym : CollectCudaSymbols(expr)) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
@@ -1259,7 +1269,7 @@ template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
unsigned hostSymbols{0};
unsigned deviceSymbols{0};
- for (const Symbol &sym : CollectSymbols(expr)) {
+ for (const Symbol &sym : CollectCudaSymbols(expr)) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index a5f4faa0cef8f..34faba39ffd46 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1000,6 +1000,35 @@ template semantics::UnorderedSymbolSet CollectSymbols(
template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SubscriptInteger> &);
+struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
+ semantics::UnorderedSymbolSet> {
+ using Base =
+ SetTraverse<CollectCudaSymbolsHelper, semantics::UnorderedSymbolSet>;
+ CollectCudaSymbolsHelper() : Base{*this} {}
+ using Base::operator();
+ semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
+ return {symbol};
+ }
+ // Overload some of the operator() to filter out the symbols that are not
+ // of interest for CUDA data transfer logic.
+ semantics::UnorderedSymbolSet operator()(const Subscript &) const {
+ return {};
+ }
+ semantics::UnorderedSymbolSet operator()(const ProcedureRef &) const {
+ return {};
+ }
+};
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &x) {
+ return CollectCudaSymbolsHelper{}(x);
+}
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SomeType> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SomeInteger> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SubscriptInteger> &);
+
// HasVectorSubscript()
struct HasVectorSubscriptHelper
: public AnyTraverse<HasVectorSubscriptHelper, bool,
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 0191de748d3eb..4929e1dcfabfc 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -6,6 +6,12 @@ module mod1
type :: t1
integer :: i
end type
+
+contains
+ function dev1(a)
+ integer, device :: a(:)
+ integer :: dev1
+ end function
end
subroutine sub1()
@@ -213,11 +219,34 @@ subroutine sub10(a, b)
res = a + b
end subroutine
-
-
! CHECK-LABEL: func.func @_QPsub10(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}
! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub10Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: cuf.data_transfer %[[A]]#1 to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
! CHECK-NOT: cuf.data_transfer
+
+subroutine sub11(n)
+ integer :: n
+ real, dimension(10) :: h
+ real, dimension(n), device :: d
+ do i=1,10
+ h(i) = d(i)
+ end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub11
+! CHECK: %[[RHS:.*]] = hlfir.designate %{{.*}} (%{{.*}}) : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+! CHECK: %[[LHS:.*]] = hlfir.designate %{{.*}} (%{{.*}}) : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
+! CHECK: cuf.data_transfer %[[RHS]] to %[[LHS]] {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<f32>, !fir.ref<f32>
+
+subroutine sub12()
+ use mod1
+ integer, device :: a(10)
+ integer :: x
+ x = dev1(a)
+end subroutine
+
+! CHECK: %{{.*}} = fir.call @_QMmod1Pdev1
+! CHECK: hlfir.assign
+! CHECK-NOT: cuf.data_transfer
|
…lvm#98378) The current data transfer detection was collecting too many symbols and made wrong decisions. This patch introduces a new function `CollectCudaSymbols` that is different from `CollectSymbols` and collects only the symbols of interest for CUDA data transfer in an expression. Currently the two cases where symbols are filtered out are: - array subscripts: only the array symbol is of interest, the indexing can be filtered out - function arguments: symbols of the function arguments are filtered out. This fixes some false-positive data transfers and implicit data transfers. More filtering might be needed and will be added in follow-up patches.
The current data transfer detection was collecting too many symbols and made wrong decisions. This patch introduces a new function
`CollectCudaSymbols`
that is different from `CollectSymbols`
and collects only the symbols of interest for CUDA data transfer in an expression. Currently the two cases where symbols are filtered out are:
- array subscripts: only the array symbol is of interest, the indexing can be filtered out
- function arguments: symbols of the function arguments are filtered out.
This fixes some false-positive data transfers and implicit data transfers.
More filtering might be needed and will be added in follow-up patches.