Skip to content

Commit abc4f74

Browse files
authored
[flang][cuda] Lower attribute for local variable (#81076)
This is a first simple patch to introduce a new FIR attribute to carry the CUDA variable attribute information to hlfir.declare and fir.declare operations. It currently lowers this information for local variables. The texture attribute is omitted since it is rejected by semantic and will not make its way to MLIR. This new attribute is added as optional attribute to the hlfir.declare and fir.declare operations.
1 parent 758fd59 commit abc4f74

File tree

12 files changed

+121
-21
lines changed

12 files changed

+121
-21
lines changed

flang/include/flang/Lower/ConvertVariable.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,12 @@ translateSymbolAttributes(mlir::MLIRContext *mlirContext,
137137
fir::FortranVariableFlagsEnum extraFlags =
138138
fir::FortranVariableFlagsEnum::None);
139139

140+
/// Translate the CUDA Fortran attributes of \p sym into the FIR CUDA attribute
141+
/// representation.
142+
fir::CUDAAttributeAttr
143+
translateSymbolCUDAAttribute(mlir::MLIRContext *mlirContext,
144+
const Fortran::semantics::Symbol &sym);
145+
140146
/// Map a symbol to a given fir::ExtendedValue. This will generate an
141147
/// hlfir.declare when lowering to HLFIR and map the hlfir.declare result to the
142148
/// symbol.

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,11 @@ translateToExtendedValue(mlir::Location loc, fir::FirOpBuilder &builder,
233233
fir::FortranVariableOpInterface fortranVariable);
234234

235235
/// Generate declaration for a fir::ExtendedValue in memory.
236-
fir::FortranVariableOpInterface genDeclare(mlir::Location loc,
237-
fir::FirOpBuilder &builder,
238-
const fir::ExtendedValue &exv,
239-
llvm::StringRef name,
240-
fir::FortranVariableFlagsAttr flags);
236+
fir::FortranVariableOpInterface
237+
genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
238+
const fir::ExtendedValue &exv, llvm::StringRef name,
239+
fir::FortranVariableFlagsAttr flags,
240+
fir::CUDAAttributeAttr cudaAttr = {});
241241

242242
/// Generate an hlfir.associate to build a variable from an expression value.
243243
/// The type of the variable must be provided so that scalar logicals are

flang/include/flang/Optimizer/Dialect/FIRAttr.td

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,28 @@ def fir_FortranVariableFlagsAttr : fir_Attr<"FortranVariableFlags"> {
5555
let returnType = "::fir::FortranVariableFlagsEnum";
5656
let convertFromStorage = "$_self.getFlags()";
5757
let constBuilderCall =
58-
"::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)";
58+
"::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)";
59+
}
60+
61+
def CUDAconstant : I32EnumAttrCase<"Constant", 0, "constant">;
62+
def CUDAdevice : I32EnumAttrCase<"Device", 1, "device">;
63+
def CUDAmanaged : I32EnumAttrCase<"Managed", 2, "managed">;
64+
def CUDApinned : I32EnumAttrCase<"Pinned", 3, "pinned">;
65+
def CUDAshared : I32EnumAttrCase<"Shared", 4, "shared">;
66+
def CUDAunified : I32EnumAttrCase<"Unified", 5, "unified">;
67+
// Texture is omitted since it is obsolete and rejected by semantic.
68+
69+
def fir_CUDAAttribute : I32EnumAttr<
70+
"CUDAAttribute",
71+
"CUDA Fortran variable attributes",
72+
[CUDAconstant, CUDAdevice, CUDAmanaged, CUDApinned, CUDAshared,
73+
CUDAunified]> {
74+
let genSpecializedAttr = 0;
75+
let cppNamespace = "::fir";
76+
}
77+
78+
def fir_CUDAAttributeAttr : EnumAttr<fir_Dialect, fir_CUDAAttribute, "cuda"> {
79+
let assemblyFormat = [{ ```<` $value `>` }];
5980
}
6081

6182
def fir_BoxFieldAttr : I32EnumAttr<

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3027,7 +3027,8 @@ def fir_DeclareOp : fir_Op<"declare", [AttrSizedOperandSegments,
30273027
Optional<AnyShapeOrShiftType>:$shape,
30283028
Variadic<AnyIntegerType>:$typeparams,
30293029
Builtin_StringAttr:$uniq_name,
3030-
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs
3030+
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs,
3031+
OptionalAttr<fir_CUDAAttributeAttr>:$cuda_attr
30313032
);
30323033

30333034
let results = (outs AnyRefOrBox);

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments,
8888
Optional<AnyShapeOrShiftType>:$shape,
8989
Variadic<AnyIntegerType>:$typeparams,
9090
Builtin_StringAttr:$uniq_name,
91-
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs
91+
OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs,
92+
OptionalAttr<fir_CUDAAttributeAttr>:$cuda_attr
9293
);
9394

9495
let results = (outs AnyFortranVariable, AnyRefOrBoxLike);
@@ -101,7 +102,8 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments,
101102
let builders = [
102103
OpBuilder<(ins "mlir::Value":$memref, "llvm::StringRef":$uniq_name,
103104
CArg<"mlir::Value", "{}">:$shape, CArg<"mlir::ValueRange", "{}">:$typeparams,
104-
CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs)>];
105+
CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs,
106+
CArg<"fir::CUDAAttributeAttr", "{}">:$cuda_attr)>];
105107

106108
let extraClassDeclaration = [{
107109
/// Get the variable original base (same as input). It lacks

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,38 @@ fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes(
15791579
return fir::FortranVariableFlagsAttr::get(mlirContext, flags);
15801580
}
15811581

1582+
fir::CUDAAttributeAttr Fortran::lower::translateSymbolCUDAAttribute(
1583+
mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
1584+
std::optional<Fortran::common::CUDADataAttr> cudaAttr =
1585+
Fortran::semantics::GetCUDADataAttr(&sym);
1586+
if (cudaAttr) {
1587+
fir::CUDAAttribute attr;
1588+
switch (*cudaAttr) {
1589+
case Fortran::common::CUDADataAttr::Constant:
1590+
attr = fir::CUDAAttribute::Constant;
1591+
break;
1592+
case Fortran::common::CUDADataAttr::Device:
1593+
attr = fir::CUDAAttribute::Device;
1594+
break;
1595+
case Fortran::common::CUDADataAttr::Managed:
1596+
attr = fir::CUDAAttribute::Managed;
1597+
break;
1598+
case Fortran::common::CUDADataAttr::Pinned:
1599+
attr = fir::CUDAAttribute::Pinned;
1600+
break;
1601+
case Fortran::common::CUDADataAttr::Shared:
1602+
attr = fir::CUDAAttribute::Shared;
1603+
break;
1604+
case Fortran::common::CUDADataAttr::Texture:
1605+
// Obsolete attribute
1606+
break;
1607+
}
1608+
1609+
return fir::CUDAAttributeAttr::get(mlirContext, attr);
1610+
}
1611+
return {};
1612+
}
1613+
15821614
/// Map a symbol to its FIR address and evaluated specification expressions.
15831615
/// Not for symbols lowered to fir.box.
15841616
/// Will optionally create fir.declare.
@@ -1618,6 +1650,8 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
16181650
auto name = converter.mangleName(sym);
16191651
fir::FortranVariableFlagsAttr attributes =
16201652
Fortran::lower::translateSymbolAttributes(builder.getContext(), sym);
1653+
fir::CUDAAttributeAttr cudaAttr =
1654+
Fortran::lower::translateSymbolCUDAAttribute(builder.getContext(), sym);
16211655

16221656
if (isCrayPointee) {
16231657
mlir::Type baseType =
@@ -1664,7 +1698,7 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter,
16641698
return;
16651699
}
16661700
auto newBase = builder.create<hlfir::DeclareOp>(
1667-
loc, base, name, shapeOrShift, lenParams, attributes);
1701+
loc, base, name, shapeOrShift, lenParams, attributes, cudaAttr);
16681702
symMap.addVariableDefinition(sym, newBase, force);
16691703
return;
16701704
}
@@ -1709,9 +1743,12 @@ void Fortran::lower::genDeclareSymbol(
17091743
fir::FortranVariableFlagsAttr attributes =
17101744
Fortran::lower::translateSymbolAttributes(
17111745
builder.getContext(), sym.GetUltimate(), extraFlags);
1746+
fir::CUDAAttributeAttr cudaAttr =
1747+
Fortran::lower::translateSymbolCUDAAttribute(builder.getContext(),
1748+
sym.GetUltimate());
17121749
auto name = converter.mangleName(sym);
17131750
hlfir::EntityWithAttributes declare =
1714-
hlfir::genDeclare(loc, builder, exv, name, attributes);
1751+
hlfir::genDeclare(loc, builder, exv, name, attributes, cudaAttr);
17151752
symMap.addVariableDefinition(sym, declare.getIfVariableInterface(), force);
17161753
return;
17171754
}

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ mlir::Value hlfir::Entity::getFirBase() const {
198198
fir::FortranVariableOpInterface
199199
hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
200200
const fir::ExtendedValue &exv, llvm::StringRef name,
201-
fir::FortranVariableFlagsAttr flags) {
201+
fir::FortranVariableFlagsAttr flags,
202+
fir::CUDAAttributeAttr cudaAttr) {
202203

203204
mlir::Value base = fir::getBase(exv);
204205
assert(fir::conformsWithPassByRef(base.getType()) &&
@@ -228,7 +229,7 @@ hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
228229
},
229230
[](const auto &) {});
230231
auto declareOp = builder.create<hlfir::DeclareOp>(
231-
loc, base, name, shapeOrShift, lenParams, flags);
232+
loc, base, name, shapeOrShift, lenParams, flags, cudaAttr);
232233
return mlir::cast<fir::FortranVariableOpInterface>(declareOp.getOperation());
233234
}
234235

flang/lib/Optimizer/Dialect/FIRAttr.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "flang/Optimizer/Dialect/FIRDialect.h"
1515
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
1616
#include "mlir/IR/AttributeSupport.h"
17+
#include "mlir/IR/Builders.h"
1718
#include "mlir/IR/BuiltinTypes.h"
1819
#include "mlir/IR/DialectImplementation.h"
1920
#include "llvm/ADT/SmallString.h"
@@ -297,5 +298,5 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
297298
void FIROpsDialect::registerAttributes() {
298299
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
299300
LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
300-
UpperBoundAttr>();
301+
UpperBoundAttr, CUDAAttributeAttr>();
301302
}

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,15 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
123123
mlir::OperationState &result, mlir::Value memref,
124124
llvm::StringRef uniq_name, mlir::Value shape,
125125
mlir::ValueRange typeparams,
126-
fir::FortranVariableFlagsAttr fortran_attrs) {
126+
fir::FortranVariableFlagsAttr fortran_attrs,
127+
fir::CUDAAttributeAttr cuda_attr) {
127128
auto nameAttr = builder.getStringAttr(uniq_name);
128129
mlir::Type inputType = memref.getType();
129130
bool hasExplicitLbs = hasExplicitLowerBounds(shape);
130131
mlir::Type hlfirVariableType =
131132
getHLFIRVariableType(inputType, hasExplicitLbs);
132133
build(builder, result, {hlfirVariableType, inputType}, memref, shape,
133-
typeparams, nameAttr, fortran_attrs);
134+
typeparams, nameAttr, fortran_attrs, cuda_attr);
134135
}
135136

136137
mlir::LogicalResult hlfir::DeclareOp::verify() {

flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,12 +320,16 @@ class DeclareOpConversion : public mlir::OpRewritePattern<hlfir::DeclareOp> {
320320
mlir::Location loc = declareOp->getLoc();
321321
mlir::Value memref = declareOp.getMemref();
322322
fir::FortranVariableFlagsAttr fortranAttrs;
323+
fir::CUDAAttributeAttr cudaAttr;
323324
if (auto attrs = declareOp.getFortranAttrs())
324325
fortranAttrs =
325326
fir::FortranVariableFlagsAttr::get(rewriter.getContext(), *attrs);
327+
if (auto attr = declareOp.getCudaAttr())
328+
cudaAttr = fir::CUDAAttributeAttr::get(rewriter.getContext(), *attr);
326329
auto firDeclareOp = rewriter.create<fir::DeclareOp>(
327330
loc, memref.getType(), memref, declareOp.getShape(),
328-
declareOp.getTypeparams(), declareOp.getUniqName(), fortranAttrs);
331+
declareOp.getTypeparams(), declareOp.getUniqName(), fortranAttrs,
332+
cudaAttr);
329333

330334
// Propagate other attributes from hlfir.declare to fir.declare.
331335
// OpenACC's acc.declare is one example. Right now, the propagation
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
! RUN: bbc -emit-hlfir -fcuda %s -o - | fir-opt -convert-hlfir-to-fir | FileCheck %s --check-prefix=FIR
3+
4+
! Test lowering of CUDA attribute on local variables.
5+
6+
subroutine local_var_attrs
7+
real, constant :: rc
8+
real, device :: rd
9+
real, allocatable, managed :: rm
10+
real, allocatable, pinned :: rp
11+
end subroutine
12+
13+
! CHECK-LABEL: func.func @_QPlocal_var_attrs()
14+
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
15+
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
16+
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
17+
! CHECK: %{{.*}}:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
18+
19+
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<constant>, uniq_name = "_QFlocal_var_attrsErc"} : (!fir.ref<f32>) -> !fir.ref<f32>
20+
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QFlocal_var_attrsErd"} : (!fir.ref<f32>) -> !fir.ref<f32>
21+
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErm"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>
22+
! FIR: %{{.*}} = fir.declare %{{.*}} {cuda_attr = #fir.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFlocal_var_attrsErp"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> !fir.ref<!fir.box<!fir.heap<f32>>>

flang/unittests/Optimizer/FortranVariableTest.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ TEST_F(FortranVariableTest, SimpleScalar) {
4949
auto name = mlir::StringAttr::get(&context, "x");
5050
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
5151
/*shape=*/mlir::Value{}, /*typeParams=*/std::nullopt, name,
52-
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{});
52+
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
53+
/*cuda_attr=*/fir::CUDAAttributeAttr{});
5354

5455
fir::FortranVariableOpInterface fortranVariable = declare;
5556
EXPECT_FALSE(fortranVariable.isArray());
@@ -74,7 +75,8 @@ TEST_F(FortranVariableTest, CharacterScalar) {
7475
auto name = mlir::StringAttr::get(&context, "x");
7576
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
7677
/*shape=*/mlir::Value{}, typeParams, name,
77-
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{});
78+
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
79+
/*cuda_attr=*/fir::CUDAAttributeAttr{});
7880

7981
fir::FortranVariableOpInterface fortranVariable = declare;
8082
EXPECT_FALSE(fortranVariable.isArray());
@@ -104,7 +106,8 @@ TEST_F(FortranVariableTest, SimpleArray) {
104106
auto name = mlir::StringAttr::get(&context, "x");
105107
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
106108
shape, /*typeParams*/ std::nullopt, name,
107-
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{});
109+
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
110+
/*cuda_attr=*/fir::CUDAAttributeAttr{});
108111

109112
fir::FortranVariableOpInterface fortranVariable = declare;
110113
EXPECT_TRUE(fortranVariable.isArray());
@@ -134,7 +137,8 @@ TEST_F(FortranVariableTest, CharacterArray) {
134137
auto name = mlir::StringAttr::get(&context, "x");
135138
auto declare = builder->create<fir::DeclareOp>(loc, addr.getType(), addr,
136139
shape, typeParams, name,
137-
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{});
140+
/*fortran_attrs=*/fir::FortranVariableFlagsAttr{},
141+
/*cuda_attr=*/fir::CUDAAttributeAttr{});
138142

139143
fir::FortranVariableOpInterface fortranVariable = declare;
140144
EXPECT_TRUE(fortranVariable.isArray());

0 commit comments

Comments
 (0)