Skip to content

Commit 2c1e3c4

Browse files
committed
[flang][cuda] Lower attribute for local variable
1 parent adbf21f commit 2c1e3c4

File tree

12 files changed

+121
-21
lines changed

12 files changed

+121
-21
lines changed

flang/include/flang/Lower/ConvertVariable.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,12 @@ translateSymbolAttributes(mlir::MLIRContext *mlirContext,
137137
fir::FortranVariableFlagsEnum extraFlags =
138138
fir::FortranVariableFlagsEnum::None);
139139

140+
/// Translate the CUDA Fortran attributes of \p sym into the FIR CUDA attribute
141+
/// representation.
142+
fir::CUDAAttributeAttr
143+
translateSymbolCUDAAttribute(mlir::MLIRContext *mlirContext,
144+
const Fortran::semantics::Symbol &sym);
145+
140146
/// Map a symbol to a given fir::ExtendedValue. This will generate an
141147
/// hlfir.declare when lowering to HLFIR and map the hlfir.declare result to the
142148
/// symbol.

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,11 @@ translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder,
233233
fir::FortranVariableOpInterface fortranVariable);
234234

235235
/// Generate declaration for a fir::ExtendedValue in memory.
236-
fir::FortranVariableOpInterface genDeclare(mlir::Location loc,
237-
fir::FirOpBuilder &builder,
238-
const fir::ExtendedValue &exv,
239-
llvm::StringRef name,
240-
fir::FortranVariableFlagsAttr flags);
236+
fir::FortranVariableOpInterface
237+
genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
238+
const fir::ExtendedValue &exv, llvm::StringRef name,
239+
fir::FortranVariableFlagsAttr flags,
240+
fir::CUDAAttributeAttr cudaAttr = {});
241241

242242
/// Generate an hlfir.associate to build a variable from an expression value.
243243
/// The type of the variable must be provided so that scalar logicals are

flang/include/flang/Optimizer/Dialect/FIRAttr.td

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,28 @@ def fir_FortranVariableFlagsAttr : fir_Attr<"FortranVariableFlags"> {
5555
let returnType = "::fir::FortranVariableFlagsEnum";
5656
let convertFromStorage = "$_self.getFlags()";
5757
let constBuilderCall =
58-
"::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)";
58+
"::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)";
59+
}
60+
61+
def CUDAconstant : I32EnumAttrCase<"Constant", 0, "constant">;
62+
def CUDAdevice : I32EnumAttrCase<"Device", 1, "device">;
63+
def CUDAmanaged : I32EnumAttrCase<"Managed", 2, "managed">;
64+
def CUDApinned : I32EnumAttrCase<"Pinned", 3, "pinned">;
65+
def CUDAshared : I32EnumAttrCase<"Shared", 4, "shared">;
66+
def CUDAunified : I32EnumAttrCase<"Unified", 5, "unified">;
67+
// Texture is omitted since it is obsolete and rejected by semantic.
68+
69+
def fir_CUDAAttribute : I32EnumAttr<
70+
"CUDAAttribute",
71+
"CUDA Fortran variable attributes",
72+
[CUDAconstant, CUDAdevice, CUDAmanaged, CUDApinned, CUDAshared,
73+
CUDAunified]> {
74+
let genSpecializedAttr = 0;
75+
let cppNamespace = "::fir";
76+
}
77+
78+
def fir_CUDAAttributeAttr : EnumAttr<fir_Dialect, fir_CUDAAttribute, "cuda"> {
79+
let assemblyFormat = [{ ```<` $value `>` }];
5980
}
6081

6182
def fir_BoxFieldAttr : I32EnumAttr<

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3027,7 +3027,8 @@ def fir_DeclareOp : fir_Op<"declare", [AttrSizedOperandSegments,
30273027
Optional<AnyShapeOrShiftType>:$shape,
30283028
Variadic<AnyIntegerType>:$typeparams,
30293029
Builtin_StringAttr:$uniq_name,
3030-
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs
3030+
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs,
3031+
OptionalAttr<fir_CUDAAttributeAttr>:$cuda_attr
30313032
);
30323033

30333034
let results = (outs AnyRefOrBox);

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments,
8888
Optional<AnyShapeOrShiftType>:$shape,
8989
Variadic<AnyIntegerType>:$typeparams,
9090
Builtin_StringAttr:$uniq_name,
91-
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs
91+
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs,
92+
OptionalAttr<fir_CUDAAttributeAttr>:$cuda_attr
9293
);
9394

9495
let results = (outs AnyFortranVariable, AnyRefOrBoxLike);
@@ -101,7 +102,8 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments,
101102
let builders = [
102103
OpBuilder<(ins "mlir::Value":$memref, "llvm::StringRef":$uniq_name,
103104
CArg<"mlir::Value", "{}">:$shape, CArg<"mlir::ValueRange", "{}">:$typeparams,
104-
CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs)>];
105+
CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs,
106+
CArg<"fir::CUDAAttributeAttr", "{}">:$cuda_attr)>];
105107

106108
let extraClassDeclaration = [{
107109
/// Get the variable original base (same as input). It lacks

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,38 @@ fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes(
15791579
return fir::FortranVariableFlagsAttr::get(mlirContext, flags);
15801580
}
15811581

1582+
fir::CUDAAttributeAttr Fortran::lower::translateSymbolCUDAAttribute(
1583+
mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
1584+
std::optional<Fortran::common::CUDADataAttr> cudaAttr =
1585+
Fortran::semantics::GetCUDADataAttr(&sym);
1586+
if (cudaAttr) {
1587+
fir::CUDAAttribute attr;
1588+
switch (*cudaAttr) {
1589+
case Fortran::common::CUDADataAttr::Constant:
1590+
attr = fir::CUDAAttribute::Constant;
1591+
break;
1592+
case Fortran::common::CUDADataAttr::Device:
1593+
attr = fir::CUDAAttribute::Device;
1594+
break;
1595+
case Fortran::common::CUDADataAttr::Managed:
1596+
attr = fir::CUDAAttribute::Managed;
1597+
break;
1598+
case Fortran::common::CUDADataAttr::Pinned:
1599+
attr = fir::CUDAAttribute::Pinned;
1600+
break;
1601+
case Fortran::common::CUDADataAttr::Shared:
1602+
attr = fir::CUDAAttribute::Shared;
1603+
break;
1604+
case Fortran::common::CUDADataAttr::Texture:
1605+
// Obsolete attribute
1606+
break;
1607+
}
1608+
1609+
return fir::CUDAAttributeAttr::get(mlirContext, attr);
1610+
}
1611+
return {};
1612+
}
1613+
15821614
/// Map a symbol to its FIR address and evaluated specification expressions.
15831615
/// Not for symbols lowered to fir.box.
15841616
/// Will optionally create fir.declare.
@@ -1618,6 +1650,8 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
16181650
auto name = converter.mangleName(sym);
16191651
fir::FortranVariableFlagsAttr attributes =
16201652
Fortran::lower::translateSymbolAttributes(builder.getContext(), sym);
1653+
fir::CUDAAttributeAttr cudaAttr =
1654+
Fortran::lower::translateSymbolCUDAAttribute(builder.getContext(), sym);
16211655

16221656
if (isCrayPointee) {
16231657
mlir::Type baseType =
@@ -1664,7 +1698,7 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
16641698
return;
16651699
}
16661700
auto newBase = builder.create<hlfir::DeclareOp>(
1667-
loc, base, name, shapeOrShift, lenParams, attributes);
1701+
loc, base, name, shapeOrShift, lenParams, attributes, cudaAttr);
16681702
symMap.addVariableDefinition(sym, newBase, force);
16691703
return;
16701704
}
@@ -1709,9 +1743,12 @@ void Fortran::lower::genDeclareSymbol(
17091743
fir::FortranVariableFlagsAttr attributes =
17101744
Fortran::lower::translateSymbolAttributes(
17111745
builder.getContext(), sym.GetUltimate(), extraFlags);
1746+
fir::CUDAAttributeAttr cudaAttr =
1747+
Fortran::lower::translateSymbolCUDAAttribute(builder.getContext(),
1748+
sym.GetUltimate());
17121749
auto name = converter.mangleName(sym);
17131750
hlfir::EntityWithAttributes declare =
1714-
hlfir::genDeclare(loc, builder, exv, name, attributes);
1751+
hlfir::genDeclare(loc, builder, exv, name, attributes, cudaAttr);
17151752
symMap.addVariableDefinition(sym, declare.getIfVariableInterface(), force);
17161753
return;
17171754
}

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ mlir::Value hlfir::Entity::getFirBase() const {
198198
fir::FortranVariableOpInterface
199199
hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
200200
const fir::ExtendedValue &exv, llvm::StringRef name,
201-
fir::FortranVariableFlagsAttr flags) {
201+
fir::FortranVariableFlagsAttr flags,
202+
fir::CUDAAttributeAttr cudaAttr) {
202203

203204
mlir::Value base = fir::getBase(exv);
204205
assert(fir::conformsWithPassByRef(base.getType()) &&
@@ -228,7 +229,7 @@ hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
228229
},
229230
[](const auto &) {});
230231
auto declareOp = builder.create<hlfir::DeclareOp>(
231-
loc, base, name, shapeOrShift, lenParams, flags);
232+
loc, base, name, shapeOrShift, lenParams, flags, cudaAttr);
232233
return mlir::cast<fir::FortranVariableOpInterface>(declareOp.getOperation());
233234
}
234235

flang/lib/Optimizer/Dialect/FIRAttr.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "flang/Optimizer/Dialect/FIRDialect.h"
1515
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
1616
#include "mlir/IR/AttributeSupport.h"
17+
#include "mlir/IR/Builders.h"
1718
#include "mlir/IR/BuiltinTypes.h"
1819
#include "mlir/IR/DialectImplementation.h"
1920
#include "llvm/ADT/SmallString.h"
@@ -297,5 +298,5 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
297298
void FIROpsDialect::registerAttributes() {
298299
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
299300
LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
300-
UpperBoundAttr>();
301+
UpperBoundAttr, CUDAAttributeAttr>();
301302
}

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,15 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
123123
mlir::OperationState &result, mlir::Value memref,
124124
llvm::StringRef uniq_name, mlir::Value shape,
125125
mlir::ValueRange typeparams,
126-
fir::FortranVariableFlagsAttr fortran_attrs) {
126+
fir::FortranVariableFlagsAttr fortran_attrs,
127+
fir::CUDAAttributeAttr cuda_attr) {
127128
auto nameAttr = builder.getStringAttr(uniq_name);
128129
mlir::Type inputType = memref.getType();
129130
bool hasExplicitLbs = hasExplicitLowerBounds(shape);
130131
mlir::Type hlfirVariableType =
131132
getHLFIRVariableType(inputType, hasExplicitLbs);
132133
build(builder, result, {hlfirVariableType, inputType}, memref, shape,
133-
typeparams, nameAttr, fortran_attrs);
134+
typeparams, nameAttr, fortran_attrs, cuda_attr);
134135
}
135136

136137
mlir::LogicalResult hlfir::DeclareOp::verify() {

flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,12 +320,16 @@ class DeclareOpConversion : public mlir::OpRewritePattern<hlfir::DeclareOp> {
320320
mlir::Location loc = declareOp->getLoc();
321321
mlir::Value memref = declareOp.getMemref();
322322
fir::FortranVariableFlagsAttr fortranAttrs;
323+
fir::CUDAAttributeAttr cudaAttr;
323324
if (auto attrs = declareOp.getFortranAttrs())
324325
fortranAttrs =
325326
fir::FortranVariableFlagsAttr::get(rewriter.getContext(), *attrs);
327+
if (auto attr = declareOp.getCudaAttr())
328+
cudaAttr = fir::CUDAAttributeAttr::get(rewriter.getContext(), *attr);
326329
auto firDeclareOp = rewriter.create<fir::DeclareOp>(
327330
loc, memref.getType(), memref, declareOp.getShape(),
328-
declareOp.getTypeparams(), declareOp.getUniqName(), fortranAttrs);
331+
declareOp.getTypeparams(), declareOp.getUniqName(), fortranAttrs,
332+
cudaAttr);
329333

330334
// Propagate other attributes from hlfir.declare to fir.declare.
331335
// OpenACC's acc.declare is one example. Right now, the propagation
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
! RUN: bbc -emit-hlfir -fcuda %s -o - | fir-opt -convert-hlfir-to-fir | FileCheck %s --check-prefix=FIR
3+
4+
! Test lowering of CUDA attribute on local variables.
5+
6+
subroutine local_var_attrs
7+
real, constant :: rc
8+
real, device :: rd
9+
real, allocatable, managed :: rm
10+
real, allocatable, pinned :: rp
11+
end subroutine
12+
13+
! CHECK-LABEL: func.func @_QPlocal_var_attrs()
14+
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
15+
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
16+
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
17+
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
18+
19+
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> !fir.ref<f32>
20+
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> !fir.ref<f32>
21+
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
22+
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>

flang/unittests/Optimizer/FortranVariableTest.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ TEST_F(FortranVariableTest, SimpleScalar) {
4949
auto name = mlir::StringAttr::get(&context, "x");
5050
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
5151
/*shape=*/mlir::Value{}, /*typeParams=*/std::nullopt, name,
52-
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{});
52+
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
53+
/*cuda_attr=*/fir::CUDAAttributeAttr{});
5354

5455
fir::FortranVariableOpInterface fortranVariable = declare;
5556
EXPECT_FALSE(fortranVariable.isArray());
@@ -74,7 +75,8 @@ TEST_F(FortranVariableTest, CharacterScalar) {
7475
auto name = mlir::StringAttr::get(&context, "x");
7576
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
7677
/*shape=*/mlir::Value{}, typeParams, name,
77-
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{});
78+
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
79+
/*cuda_attr=*/fir::CUDAAttributeAttr{});
7880

7981
fir::FortranVariableOpInterface fortranVariable = declare;
8082
EXPECT_FALSE(fortranVariable.isArray());
@@ -104,7 +106,8 @@ TEST_F(FortranVariableTest, SimpleArray) {
104106
auto name = mlir::StringAttr::get(&context, "x");
105107
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
106108
shape, /*typeParams*/ std::nullopt, name,
107-
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{});
109+
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
110+
/*cuda_attr=*/fir::CUDAAttributeAttr{});
108111

109112
fir::FortranVariableOpInterface fortranVariable = declare;
110113
EXPECT_TRUE(fortranVariable.isArray());
@@ -134,7 +137,8 @@ TEST_F(FortranVariableTest, CharacterArray) {
134137
auto name = mlir::StringAttr::get(&context, "x");
135138
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
136139
shape, typeParams, name,
137-
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{});
140+
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
141+
/*cuda_attr=*/fir::CUDAAttributeAttr{});
138142

139143
fir::FortranVariableOpInterface fortranVariable = declare;
140144
EXPECT_TRUE(fortranVariable.isArray());

0 commit comments

Comments
 (0)