Skip to content

Commit b06bc7c

Browse files
authored
[mlir][flang][openacc] Device type support on acc routine op (#78375)
This patch add support for device_type on the acc.routine operation. device_type can be specified on seq, worker, vector, gang and bind information. The support is following the same design than the one for compute operations, data operation and the loop operation.
1 parent cb2f340 commit b06bc7c

File tree

6 files changed

+498
-68
lines changed

6 files changed

+498
-68
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 117 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3469,6 +3469,72 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
34693469
llvm_unreachable("unsupported declarative directive");
34703470
}
34713471

3472+
static bool hasDeviceType(llvm::SmallVector<mlir::Attribute> &arrayAttr,
3473+
mlir::acc::DeviceType deviceType) {
3474+
for (auto attr : arrayAttr) {
3475+
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3476+
if (deviceTypeAttr.getValue() == deviceType)
3477+
return true;
3478+
}
3479+
return false;
3480+
}
3481+
3482+
template <typename RetTy, typename AttrTy>
3483+
static std::optional<RetTy>
3484+
getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
3485+
llvm::SmallVector<mlir::Attribute> &deviceTypes,
3486+
mlir::acc::DeviceType deviceType) {
3487+
assert(attributes.size() == deviceTypes.size() &&
3488+
"expect same number of attributes");
3489+
for (auto it : llvm::enumerate(deviceTypes)) {
3490+
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value());
3491+
if (deviceTypeAttr.getValue() == deviceType) {
3492+
if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) {
3493+
auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index()]);
3494+
return strAttr.getValue();
3495+
} else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) {
3496+
auto intAttr =
3497+
mlir::dyn_cast<mlir::IntegerAttr>(attributes[it.index()]);
3498+
return intAttr.getInt();
3499+
}
3500+
}
3501+
}
3502+
return std::nullopt;
3503+
}
3504+
3505+
static bool compareDeviceTypeInfo(
3506+
mlir::acc::RoutineOp op,
3507+
llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
3508+
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
3509+
llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
3510+
llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
3511+
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
3512+
llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
3513+
llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
3514+
llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
3515+
for (uint32_t dtypeInt = 0;
3516+
dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
3517+
auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
3518+
if (op.getBindNameValue(dtype) !=
3519+
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
3520+
bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
3521+
return false;
3522+
if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
3523+
return false;
3524+
if (op.getGangDimValue(dtype) !=
3525+
getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>(
3526+
gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
3527+
return false;
3528+
if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype))
3529+
return false;
3530+
if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype))
3531+
return false;
3532+
if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype))
3533+
return false;
3534+
}
3535+
return true;
3536+
}
3537+
34723538
static void attachRoutineInfo(mlir::func::FuncOp func,
34733539
mlir::SymbolRefAttr routineAttr) {
34743540
llvm::SmallVector<mlir::SymbolRefAttr> routines;
@@ -3518,17 +3584,23 @@ void Fortran::lower::genOpenACCRoutineConstruct(
35183584
funcName = funcOp.getName();
35193585
}
35203586
}
3521-
bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
3522-
hasNohost = false;
3523-
std::optional<std::string> bindName = std::nullopt;
3524-
std::optional<int64_t> gangDim = std::nullopt;
3587+
bool hasNohost = false;
3588+
3589+
llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
3590+
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
3591+
gangDimDeviceTypes, gangDimValues;
3592+
3593+
// device_type attribute is set to `none` until a device_type clause is
3594+
// encountered.
3595+
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
3596+
builder.getContext(), mlir::acc::DeviceType::None);
35253597

35263598
for (const Fortran::parser::AccClause &clause : clauses.v) {
35273599
if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
3528-
hasSeq = true;
3600+
seqDeviceTypes.push_back(crtDeviceTypeAttr);
35293601
} else if (const auto *gangClause =
35303602
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
3531-
hasGang = true;
3603+
35323604
if (gangClause->v) {
35333605
const Fortran::parser::AccGangArgList &x = *gangClause->v;
35343606
for (const Fortran::parser::AccGangArg &gangArg : x.v) {
@@ -3539,21 +3611,27 @@ void Fortran::lower::genOpenACCRoutineConstruct(
35393611
if (!dimValue)
35403612
mlir::emitError(loc,
35413613
"dim value must be a constant positive integer");
3542-
gangDim = *dimValue;
3614+
gangDimValues.push_back(
3615+
builder.getIntegerAttr(builder.getI64Type(), *dimValue));
3616+
gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
35433617
}
35443618
}
3619+
} else {
3620+
gangDeviceTypes.push_back(crtDeviceTypeAttr);
35453621
}
35463622
} else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
3547-
hasVector = true;
3623+
vectorDeviceTypes.push_back(crtDeviceTypeAttr);
35483624
} else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
3549-
hasWorker = true;
3625+
workerDeviceTypes.push_back(crtDeviceTypeAttr);
35503626
} else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
35513627
hasNohost = true;
35523628
} else if (const auto *bindClause =
35533629
std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
35543630
if (const auto *name =
35553631
std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
3556-
bindName = converter.mangleName(*name->symbol);
3632+
bindNames.push_back(
3633+
builder.getStringAttr(converter.mangleName(*name->symbol)));
3634+
bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
35573635
} else if (const auto charExpr =
35583636
std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
35593637
&bindClause->v.u)) {
@@ -3562,8 +3640,18 @@ void Fortran::lower::genOpenACCRoutineConstruct(
35623640
*charExpr);
35633641
if (!name)
35643642
mlir::emitError(loc, "Could not retrieve the bind name");
3565-
bindName = *name;
3643+
bindNames.push_back(builder.getStringAttr(*name));
3644+
bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
35663645
}
3646+
} else if (const auto *deviceTypeClause =
3647+
std::get_if<Fortran::parser::AccClause::DeviceType>(
3648+
&clause.u)) {
3649+
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
3650+
deviceTypeClause->v;
3651+
assert(deviceTypeExprList.v.size() == 1 &&
3652+
"expect only one device_type expr");
3653+
crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
3654+
builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
35673655
}
35683656
}
35693657

@@ -3575,23 +3663,31 @@ void Fortran::lower::genOpenACCRoutineConstruct(
35753663
if (routineOp.getFuncName().str().compare(funcName) == 0) {
35763664
// If the routine is already specified with the same clauses, just skip
35773665
// the operation creation.
3578-
if (routineOp.getBindName() == bindName &&
3579-
routineOp.getGang() == hasGang &&
3580-
routineOp.getWorker() == hasWorker &&
3581-
routineOp.getVector() == hasVector && routineOp.getSeq() == hasSeq &&
3582-
routineOp.getNohost() == hasNohost &&
3583-
routineOp.getGangDim() == gangDim)
3666+
if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
3667+
gangDeviceTypes, gangDimValues,
3668+
gangDimDeviceTypes, seqDeviceTypes,
3669+
workerDeviceTypes, vectorDeviceTypes) &&
3670+
routineOp.getNohost() == hasNohost)
35843671
return;
35853672
mlir::emitError(loc, "Routine already specified with different clauses");
35863673
}
35873674
}
35883675

35893676
modBuilder.create<mlir::acc::RoutineOp>(
35903677
loc, routineOpName.str(), funcName,
3591-
bindName ? builder.getStringAttr(*bindName) : mlir::StringAttr{}, hasGang,
3592-
hasWorker, hasVector, hasSeq, hasNohost, /*implicit=*/false,
3593-
gangDim ? builder.getIntegerAttr(builder.getIntegerType(32), *gangDim)
3594-
: mlir::IntegerAttr{});
3678+
bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
3679+
bindNameDeviceTypes.empty() ? nullptr
3680+
: builder.getArrayAttr(bindNameDeviceTypes),
3681+
workerDeviceTypes.empty() ? nullptr
3682+
: builder.getArrayAttr(workerDeviceTypes),
3683+
vectorDeviceTypes.empty() ? nullptr
3684+
: builder.getArrayAttr(vectorDeviceTypes),
3685+
seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
3686+
hasNohost, /*implicit=*/false,
3687+
gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
3688+
gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
3689+
gangDimDeviceTypes.empty() ? nullptr
3690+
: builder.getArrayAttr(gangDimDeviceTypes));
35953691

35963692
if (funcOp)
35973693
attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));

flang/test/Lower/OpenACC/acc-routine.f90

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
44

5-
5+
! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
6+
! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
7+
! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
68
! CHECK: acc.routine @acc_routine_10 func(@_QPacc_routine11) seq
79
! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq
810
! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a")
911
! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_")
10-
! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim = 1 : i32)
12+
! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim: 1 : i64)
1113
! CHECK: acc.routine @acc_routine_5 func(@_QPacc_routine6) nohost
1214
! CHECK: acc.routine @acc_routine_4 func(@_QPacc_routine5) worker
1315
! CHECK: acc.routine @acc_routine_3 func(@_QPacc_routine4) vector
@@ -106,3 +108,15 @@ subroutine acc_routine14()
106108
subroutine acc_routine15()
107109
!$acc routine bind(acc_routine16)
108110
end subroutine
111+
112+
subroutine acc_routine16()
113+
!$acc routine device_type(host) seq dtype(nvidia) gang
114+
end subroutine
115+
116+
subroutine acc_routine17()
117+
!$acc routine device_type(host) worker dtype(multicore) vector
118+
end subroutine
119+
120+
subroutine acc_routine18()
121+
!$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16)
122+
end subroutine

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,27 +1994,63 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
19941994

19951995
let arguments = (ins SymbolNameAttr:$sym_name,
19961996
SymbolNameAttr:$func_name,
1997-
OptionalAttr<StrAttr>:$bind_name,
1998-
UnitAttr:$gang,
1999-
UnitAttr:$worker,
2000-
UnitAttr:$vector,
2001-
UnitAttr:$seq,
1997+
OptionalAttr<StrArrayAttr>:$bindName,
1998+
OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
1999+
OptionalAttr<DeviceTypeArrayAttr>:$worker,
2000+
OptionalAttr<DeviceTypeArrayAttr>:$vector,
2001+
OptionalAttr<DeviceTypeArrayAttr>:$seq,
20022002
UnitAttr:$nohost,
20032003
UnitAttr:$implicit,
2004-
OptionalAttr<APIntAttr>:$gangDim);
2004+
OptionalAttr<DeviceTypeArrayAttr>:$gang,
2005+
OptionalAttr<I64ArrayAttr>:$gangDim,
2006+
OptionalAttr<DeviceTypeArrayAttr>:$gangDimDeviceType);
20052007

20062008
let extraClassDeclaration = [{
20072009
static StringRef getGangDimKeyword() { return "dim"; }
2010+
2011+
/// Return true if the op has the worker attribute for the
2012+
/// mlir::acc::DeviceType::None device_type.
2013+
bool hasWorker();
2014+
/// Return true if the op has the worker attribute for the given
2015+
/// device_type.
2016+
bool hasWorker(mlir::acc::DeviceType deviceType);
2017+
2018+
/// Return true if the op has the vector attribute for the
2019+
/// mlir::acc::DeviceType::None device_type.
2020+
bool hasVector();
2021+
/// Return true if the op has the vector attribute for the given
2022+
/// device_type.
2023+
bool hasVector(mlir::acc::DeviceType deviceType);
2024+
2025+
/// Return true if the op has the seq attribute for the
2026+
/// mlir::acc::DeviceType::None device_type.
2027+
bool hasSeq();
2028+
/// Return true if the op has the seq attribute for the given
2029+
/// device_type.
2030+
bool hasSeq(mlir::acc::DeviceType deviceType);
2031+
2032+
/// Return true if the op has the gang attribute for the
2033+
/// mlir::acc::DeviceType::None device_type.
2034+
bool hasGang();
2035+
/// Return true if the op has the gang attribute for the given
2036+
/// device_type.
2037+
bool hasGang(mlir::acc::DeviceType deviceType);
2038+
2039+
std::optional<int64_t> getGangDimValue();
2040+
std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
2041+
2042+
std::optional<llvm::StringRef> getBindNameValue();
2043+
std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
20082044
}];
20092045

20102046
let assemblyFormat = [{
20112047
$sym_name `func` `(` $func_name `)`
20122048
oilist (
2013-
`bind` `(` $bind_name `)`
2014-
| `gang` `` custom<RoutineGangClause>($gang, $gangDim)
2015-
| `worker` $worker
2016-
| `vector` $vector
2017-
| `seq` $seq
2049+
`bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
2050+
| `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
2051+
| `worker` custom<DeviceTypeArrayAttr>($worker)
2052+
| `vector` custom<DeviceTypeArrayAttr>($vector)
2053+
| `seq` custom<DeviceTypeArrayAttr>($seq)
20182054
| `nohost` $nohost
20192055
| `implicit` $implicit
20202056
) attr-dict-with-keyword

0 commit comments

Comments
 (0)