Skip to content

Commit c560ce4

Browse files
authored
[flang][cuda] Lower attribute for dummy argument (#81212)
Lower CUDA attribute for simple dummy argument. This is done in a similar way than `TARGET`, `OPTIONAL` and so on. This patch also move the `Fortran::common::CUDADataAttr` to `fir::CUDAAttributeAttr` mapping to `flang/include/flang/Optimizer/Support/Utils.h` so that it can be reused where needed.
1 parent ffabcbc commit c560ce4

File tree

5 files changed

+74
-27
lines changed

5 files changed

+74
-27
lines changed

flang/include/flang/Optimizer/Dialect/FIROpsSupport.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ constexpr llvm::StringRef getOptionalAttrName() { return "fir.optional"; }
7272
/// Attribute to mark Fortran entities with the TARGET attribute.
7373
static constexpr llvm::StringRef getTargetAttrName() { return "fir.target"; }
7474

75+
/// Attribute to mark Fortran entities with the CUDA attribute.
76+
static constexpr llvm::StringRef getCUDAAttrName() { return "fir.cuda_attr"; }
77+
7578
/// Attribute to mark that a function argument is a character dummy procedure.
7679
/// Character dummy procedure have special ABI constraints.
7780
static constexpr llvm::StringRef getCharacterProcedureDummyAttrName() {

flang/include/flang/Optimizer/Support/Utils.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,36 @@ inline void genMinMaxlocReductionLoop(
273273
builder.setInsertionPointAfter(ifMaskTrueOp);
274274
}
275275

276+
inline fir::CUDAAttributeAttr
277+
getCUDAAttribute(mlir::MLIRContext *mlirContext,
278+
std::optional<Fortran::common::CUDADataAttr> cudaAttr) {
279+
if (cudaAttr) {
280+
fir::CUDAAttribute attr;
281+
switch (*cudaAttr) {
282+
case Fortran::common::CUDADataAttr::Constant:
283+
attr = fir::CUDAAttribute::Constant;
284+
break;
285+
case Fortran::common::CUDADataAttr::Device:
286+
attr = fir::CUDAAttribute::Device;
287+
break;
288+
case Fortran::common::CUDADataAttr::Managed:
289+
attr = fir::CUDAAttribute::Managed;
290+
break;
291+
case Fortran::common::CUDADataAttr::Pinned:
292+
attr = fir::CUDAAttribute::Pinned;
293+
break;
294+
case Fortran::common::CUDADataAttr::Shared:
295+
attr = fir::CUDAAttribute::Shared;
296+
break;
297+
case Fortran::common::CUDADataAttr::Texture:
298+
// Obsolete attribute
299+
return {};
300+
}
301+
return fir::CUDAAttributeAttr::get(mlirContext, attr);
302+
}
303+
return {};
304+
}
305+
276306
} // namespace fir
277307

278308
#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H

flang/lib/Lower/CallInterface.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "flang/Optimizer/Dialect/FIRDialect.h"
2020
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
2121
#include "flang/Optimizer/Support/InternalNames.h"
22+
#include "flang/Optimizer/Support/Utils.h"
2223
#include "flang/Semantics/symbol.h"
2324
#include "flang/Semantics/tools.h"
2425
#include <optional>
@@ -993,6 +994,10 @@ class Fortran::lower::CallInterfaceImpl {
993994
TODO(loc, "VOLATILE in procedure interface");
994995
if (obj.attrs.test(Attrs::Target))
995996
addMLIRAttr(fir::getTargetAttrName());
997+
if (obj.cudaDataAttr)
998+
attrs.emplace_back(
999+
mlir::StringAttr::get(&mlirContext, fir::getCUDAAttrName()),
1000+
fir::getCUDAAttribute(&mlirContext, obj.cudaDataAttr));
9961001

9971002
// TODO: intents that require special care (e.g finalization)
9981003

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "flang/Optimizer/HLFIR/HLFIROps.h"
3838
#include "flang/Optimizer/Support/FatalError.h"
3939
#include "flang/Optimizer/Support/InternalNames.h"
40+
#include "flang/Optimizer/Support/Utils.h"
4041
#include "flang/Semantics/runtime-type-info.h"
4142
#include "flang/Semantics/tools.h"
4243
#include "llvm/Support/Debug.h"
@@ -1583,32 +1584,7 @@ fir::CUDAAttributeAttr Fortran::lower::translateSymbolCUDAAttribute(
15831584
mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
15841585
std::optional<Fortran::common::CUDADataAttr> cudaAttr =
15851586
Fortran::semantics::GetCUDADataAttr(&sym);
1586-
if (cudaAttr) {
1587-
fir::CUDAAttribute attr;
1588-
switch (*cudaAttr) {
1589-
case Fortran::common::CUDADataAttr::Constant:
1590-
attr = fir::CUDAAttribute::Constant;
1591-
break;
1592-
case Fortran::common::CUDADataAttr::Device:
1593-
attr = fir::CUDAAttribute::Device;
1594-
break;
1595-
case Fortran::common::CUDADataAttr::Managed:
1596-
attr = fir::CUDAAttribute::Managed;
1597-
break;
1598-
case Fortran::common::CUDADataAttr::Pinned:
1599-
attr = fir::CUDAAttribute::Pinned;
1600-
break;
1601-
case Fortran::common::CUDADataAttr::Shared:
1602-
attr = fir::CUDAAttribute::Shared;
1603-
break;
1604-
case Fortran::common::CUDADataAttr::Texture:
1605-
// Obsolete attribute
1606-
return {};
1607-
}
1608-
1609-
return fir::CUDAAttributeAttr::get(mlirContext, attr);
1610-
}
1611-
return {};
1587+
return fir::getCUDAAttribute(mlirContext, cudaAttr);
16121588
}
16131589

16141590
/// Map a symbol to its FIR address and evaluated specification expressions.

flang/test/Lower/CUDA/cuda-data-attribute.cuf

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
22
! RUN: bbc -emit-hlfir -fcuda %s -o - | fir-opt -convert-hlfir-to-fir | FileCheck %s --check-prefix=FIR
33

4-
! Test lowering of CUDA attribute on local variables.
4+
! Test lowering of CUDA attribute on variables.
55

66
subroutine local_var_attrs
77
real, constant :: rc
@@ -20,3 +20,36 @@ end subroutine
2020
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> !fir.ref<f32>
2121
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
2222
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
23+
24+
subroutine dummy_arg_constant(dc)
25+
real, constant :: dc
26+
end subroutine
27+
! CHECK-LABEL: func.func @_QPdummy_arg_constant(
28+
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "dc", fir.cuda_attr = #fir.cuda<constant>}
29+
! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFdummy_arg_constantEdc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
30+
31+
subroutine dummy_arg_device(dd)
32+
real, device :: dd
33+
end subroutine
34+
! CHECK-LABEL: func.func @_QPdummy_arg_device(
35+
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32> {fir.bindc_name = "dd", fir.cuda_attr = #fir.cuda<device>}) {
36+
! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<device>, uniq_name = "_QFdummy_arg_deviceEdd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
37+
38+
subroutine dummy_arg_managed(dm)
39+
real, allocatable, managed :: dm
40+
end subroutine
41+
! CHECK-LABEL: func.func @_QPdummy_arg_managed(
42+
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<f32>>> {fir.bindc_name = "dm", fir.cuda_attr = #fir.cuda<managed>}) {
43+
! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFdummy_arg_managedEdm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
44+
45+
subroutine dummy_arg_pinned(dp)
46+
real, allocatable, pinned :: dp
47+
end subroutine
48+
! CHECK-LABEL: func.func @_QPdummy_arg_pinned(
49+
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<f32>>> {fir.bindc_name = "dp", fir.cuda_attr = #fir.cuda<pinned>}) {
50+
! CHECK: %{{.*}}:2 = hlfir.declare %[[ARG0]] {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFdummy_arg_pinnedEdp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
51+
52+
53+
54+
55+

0 commit comments

Comments
 (0)