Skip to content

[flang][cuda] Lower attribute for dummy argument #81212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 9, 2024

Conversation

clementval
Copy link
Contributor

Lower CUDA attribute for simple dummy argument. This is done in a similar way than TARGET, OPTIONAL and so on.

This patch also move the Fortran::common::CUDADataAttr to fir::CUDAAttributeAttr mapping to flang/include/flang/Optimizer/Support/Utils.h so that it can be reused where needed.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 9, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 9, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Lower CUDA attribute for simple dummy argument. This is done in a similar way than TARGET, OPTIONAL and so on.

This patch also move the Fortran::common::CUDADataAttr to fir::CUDAAttributeAttr mapping to flang/include/flang/Optimizer/Support/Utils.h so that it can be reused where needed.


Full diff: https://github.com/llvm/llvm-project/pull/81212.diff

5 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROpsSupport.h (+3)
  • (modified) flang/include/flang/Optimizer/Support/Utils.h (+30)
  • (modified) flang/lib/Lower/CallInterface.cpp (+5)
  • (modified) flang/lib/Lower/ConvertVariable.cpp (+2-26)
  • (modified) flang/test/Lower/CUDA/cuda-data-attribute.cuf (+34-1)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index 977949e399a53b..6ac6a3116d40b0 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -72,6 +72,9 @@ constexpr llvm::StringRef getOptionalAttrName() { return "fir.optional"; }
 /// Attribute to mark Fortran entities with the TARGET attribute.
 static constexpr llvm::StringRef getTargetAttrName() { return "fir.target"; }
 
+/// Attribute to mark Fortran entities with the CUDA attribute.
+static constexpr llvm::StringRef getCUDAAttrName() { return "fir.cuda_attr"; }
+
 /// Attribute to mark that a function argument is a character dummy procedure.
 /// Character dummy procedure have special ABI constraints.
 static constexpr llvm::StringRef getCharacterProcedureDummyAttrName() {
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index b50f297a7d3141..586701b4c54d14 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -273,6 +273,36 @@ inline void genMinMaxlocReductionLoop(
   builder.setInsertionPointAfter(ifMaskTrueOp);
 }
 
+inline fir::CUDAAttributeAttr
+getCUDAAttribute(mlir::MLIRContext *mlirContext,
+                 std::optional<Fortran::common::CUDADataAttr> cudaAttr) {
+  if (cudaAttr) {
+    fir::CUDAAttribute attr;
+    switch (*cudaAttr) {
+    case Fortran::common::CUDADataAttr::Constant:
+      attr = fir::CUDAAttribute::Constant;
+      break;
+    case Fortran::common::CUDADataAttr::Device:
+      attr = fir::CUDAAttribute::Device;
+      break;
+    case Fortran::common::CUDADataAttr::Managed:
+      attr = fir::CUDAAttribute::Managed;
+      break;
+    case Fortran::common::CUDADataAttr::Pinned:
+      attr = fir::CUDAAttribute::Pinned;
+      break;
+    case Fortran::common::CUDADataAttr::Shared:
+      attr = fir::CUDAAttribute::Shared;
+      break;
+    case Fortran::common::CUDADataAttr::Texture:
+      // Obsolete attribute
+      return {};
+    }
+    return fir::CUDAAttributeAttr::get(mlirContext, attr);
+  }
+  return {};
+}
+
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index b007c958cb6b33..4c297ceffc536d 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -19,6 +19,7 @@
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/Support/InternalNames.h"
+#include "flang/Optimizer/Support/Utils.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include <optional>
@@ -993,6 +994,10 @@ class Fortran::lower::CallInterfaceImpl {
       TODO(loc, "VOLATILE in procedure interface");
     if (obj.attrs.test(Attrs::Target))
       addMLIRAttr(fir::getTargetAttrName());
+    if (obj.cudaDataAttr)
+      attrs.emplace_back(
+          mlir::StringAttr::get(&mlirContext, fir::getCUDAAttrName()),
+          fir::getCUDAAttribute(&mlirContext, obj.cudaDataAttr));
 
     // TODO: intents that require special care (e.g finalization)
 
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index d57bdd448da3f6..f14267f1234217 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -37,6 +37,7 @@
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/Support/FatalError.h"
 #include "flang/Optimizer/Support/InternalNames.h"
+#include "flang/Optimizer/Support/Utils.h"
 #include "flang/Semantics/runtime-type-info.h"
 #include "flang/Semantics/tools.h"
 #include "llvm/Support/Debug.h"
@@ -1583,32 +1584,7 @@ fir::CUDAAttributeAttr Fortran::lower::translateSymbolCUDAAttribute(
     mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
   std::optional<Fortran::common::CUDADataAttr> cudaAttr =
       Fortran::semantics::GetCUDADataAttr(&sym);
-  if (cudaAttr) {
-    fir::CUDAAttribute attr;
-    switch (*cudaAttr) {
-    case Fortran::common::CUDADataAttr::Constant:
-      attr = fir::CUDAAttribute::Constant;
-      break;
-    case Fortran::common::CUDADataAttr::Device:
-      attr = fir::CUDAAttribute::Device;
-      break;
-    case Fortran::common::CUDADataAttr::Managed:
-      attr = fir::CUDAAttribute::Managed;
-      break;
-    case Fortran::common::CUDADataAttr::Pinned:
-      attr = fir::CUDAAttribute::Pinned;
-      break;
-    case Fortran::common::CUDADataAttr::Shared:
-      attr = fir::CUDAAttribute::Shared;
-      break;
-    case Fortran::common::CUDADataAttr::Texture:
-      // Obsolete attribute
-      return {};
-    }
-
-    return fir::CUDAAttributeAttr::get(mlirContext, attr);
-  }
-  return {};
+  return fir::getCUDAAttribute(mlirContext, cudaAttr);
 }
 
 /// Map a symbol to its FIR address and evaluated specification expressions.
diff --git a/flang/test/Lower/CUDA/cuda-data-attribute.cuf b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
index caa8ac7baff383..b02701bf3aea5a 100644
--- a/flang/test/Lower/CUDA/cuda-data-attribute.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-attribute.cuf
@@ -1,7 +1,7 @@
 ! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
 ! RUN: bbc -emit-hlfir -fcuda %s -o - | fir-opt -convert-hlfir-to-fir | FileCheck %s --check-prefix=FIR
 
-! Test lowering of CUDA attribute on local variables.
+! Test lowering of CUDA attribute on variables.
 
 subroutine local_var_attrs
   real, constant :: rc
@@ -20,3 +20,36 @@ end subroutine
 ! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> !fir.ref<f32>
 ! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
 ! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
+
+subroutine dummy_arg_constant(dc)
+  real, constant :: dc
+end subroutine
+! CHECK-LABEL: func.func @_QPdummy_arg_constant(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "dc", fir.cuda_attr = #fir.cuda<constant>}
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFdummy_arg_constantEdc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+
+subroutine dummy_arg_device(dd)
+  real, device :: dd
+end subroutine
+! CHECK-LABEL: func.func @_QPdummy_arg_device(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "dd", fir.cuda_attr = #fir.cuda<device>}) {
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<device>, uniq_name = "_QFdummy_arg_deviceEdd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+
+subroutine dummy_arg_managed(dm)
+  real, allocatable, managed :: dm
+end subroutine
+! CHECK-LABEL: func.func @_QPdummy_arg_managed(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<f32>>> {fir.bindc_name = "dm", fir.cuda_attr = #fir.cuda<managed>}) {
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFdummy_arg_managedEdm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+
+subroutine dummy_arg_pinned(dp)
+  real, allocatable, pinned :: dp
+end subroutine
+! CHECK-LABEL: func.func @_QPdummy_arg_pinned(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<f32>>> {fir.bindc_name = "dp", fir.cuda_attr = #fir.cuda<pinned>}) {
+! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFdummy_arg_pinnedEdp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+
+
+
+
+

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@clementval clementval merged commit c560ce4 into llvm:main Feb 9, 2024
@clementval clementval deleted the cuda_dummy_arg_attr branch February 9, 2024 02:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants