Skip to content

[mlir][flang][openacc] Device type support on acc routine op #78375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 117 additions & 21 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3469,6 +3469,72 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
llvm_unreachable("unsupported declarative directive");
}

static bool hasDeviceType(llvm::SmallVector<mlir::Attribute> &arrayAttr,
mlir::acc::DeviceType deviceType) {
for (auto attr : arrayAttr) {
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
if (deviceTypeAttr.getValue() == deviceType)
return true;
}
return false;
}

template <typename RetTy, typename AttrTy>
static std::optional<RetTy>
getAttributeValueByDeviceType(llvm::SmallVector<mlir::Attribute> &attributes,
llvm::SmallVector<mlir::Attribute> &deviceTypes,
mlir::acc::DeviceType deviceType) {
assert(attributes.size() == deviceTypes.size() &&
"expect same number of attributes");
for (auto it : llvm::enumerate(deviceTypes)) {
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(it.value());
if (deviceTypeAttr.getValue() == deviceType) {
if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>) {
auto strAttr = mlir::dyn_cast<AttrTy>(attributes[it.index()]);
return strAttr.getValue();
} else if constexpr (std::is_same_v<mlir::IntegerAttr, AttrTy>) {
auto intAttr =
mlir::dyn_cast<mlir::IntegerAttr>(attributes[it.index()]);
return intAttr.getInt();
}
}
}
return std::nullopt;
}

static bool compareDeviceTypeInfo(
mlir::acc::RoutineOp op,
llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
for (uint32_t dtypeInt = 0;
dtypeInt != mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
if (op.getBindNameValue(dtype) !=
getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
return false;
if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
return false;
if (op.getGangDimValue(dtype) !=
getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>(
gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
return false;
if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype))
return false;
if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype))
return false;
if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype))
return false;
}
return true;
}

static void attachRoutineInfo(mlir::func::FuncOp func,
mlir::SymbolRefAttr routineAttr) {
llvm::SmallVector<mlir::SymbolRefAttr> routines;
Expand Down Expand Up @@ -3518,17 +3584,23 @@ void Fortran::lower::genOpenACCRoutineConstruct(
funcName = funcOp.getName();
}
}
bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
hasNohost = false;
std::optional<std::string> bindName = std::nullopt;
std::optional<int64_t> gangDim = std::nullopt;
bool hasNohost = false;

llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
gangDimDeviceTypes, gangDimValues;

// device_type attribute is set to `none` until a device_type clause is
// encountered.
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
builder.getContext(), mlir::acc::DeviceType::None);

for (const Fortran::parser::AccClause &clause : clauses.v) {
if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
hasSeq = true;
seqDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto *gangClause =
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
hasGang = true;

if (gangClause->v) {
const Fortran::parser::AccGangArgList &x = *gangClause->v;
for (const Fortran::parser::AccGangArg &gangArg : x.v) {
Expand All @@ -3539,21 +3611,27 @@ void Fortran::lower::genOpenACCRoutineConstruct(
if (!dimValue)
mlir::emitError(loc,
"dim value must be a constant positive integer");
gangDim = *dimValue;
gangDimValues.push_back(
builder.getIntegerAttr(builder.getI64Type(), *dimValue));
gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
}
}
} else {
gangDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
hasVector = true;
vectorDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
hasWorker = true;
workerDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
hasNohost = true;
} else if (const auto *bindClause =
std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
if (const auto *name =
std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
bindName = converter.mangleName(*name->symbol);
bindNames.push_back(
builder.getStringAttr(converter.mangleName(*name->symbol)));
bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto charExpr =
std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
&bindClause->v.u)) {
Expand All @@ -3562,8 +3640,18 @@ void Fortran::lower::genOpenACCRoutineConstruct(
*charExpr);
if (!name)
mlir::emitError(loc, "Could not retrieve the bind name");
bindName = *name;
bindNames.push_back(builder.getStringAttr(*name));
bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
deviceTypeClause->v;
assert(deviceTypeExprList.v.size() == 1 &&
"expect only one device_type expr");
crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
}
}

Expand All @@ -3575,23 +3663,31 @@ void Fortran::lower::genOpenACCRoutineConstruct(
if (routineOp.getFuncName().str().compare(funcName) == 0) {
// If the routine is already specified with the same clauses, just skip
// the operation creation.
if (routineOp.getBindName() == bindName &&
routineOp.getGang() == hasGang &&
routineOp.getWorker() == hasWorker &&
routineOp.getVector() == hasVector && routineOp.getSeq() == hasSeq &&
routineOp.getNohost() == hasNohost &&
routineOp.getGangDim() == gangDim)
if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
gangDeviceTypes, gangDimValues,
gangDimDeviceTypes, seqDeviceTypes,
workerDeviceTypes, vectorDeviceTypes) &&
routineOp.getNohost() == hasNohost)
return;
mlir::emitError(loc, "Routine already specified with different clauses");
}
}

modBuilder.create<mlir::acc::RoutineOp>(
loc, routineOpName.str(), funcName,
bindName ? builder.getStringAttr(*bindName) : mlir::StringAttr{}, hasGang,
hasWorker, hasVector, hasSeq, hasNohost, /*implicit=*/false,
gangDim ? builder.getIntegerAttr(builder.getIntegerType(32), *gangDim)
: mlir::IntegerAttr{});
bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
bindNameDeviceTypes.empty() ? nullptr
: builder.getArrayAttr(bindNameDeviceTypes),
workerDeviceTypes.empty() ? nullptr
: builder.getArrayAttr(workerDeviceTypes),
vectorDeviceTypes.empty() ? nullptr
: builder.getArrayAttr(vectorDeviceTypes),
seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
hasNohost, /*implicit=*/false,
gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
gangDimDeviceTypes.empty() ? nullptr
: builder.getArrayAttr(gangDimDeviceTypes));

if (funcOp)
attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));
Expand Down
18 changes: 16 additions & 2 deletions flang/test/Lower/OpenACC/acc-routine.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s


! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
! CHECK: acc.routine @acc_routine_10 func(@_QPacc_routine11) seq
! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq
! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a")
! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_")
! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim = 1 : i32)
! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim: 1 : i64)
! CHECK: acc.routine @acc_routine_5 func(@_QPacc_routine6) nohost
! CHECK: acc.routine @acc_routine_4 func(@_QPacc_routine5) worker
! CHECK: acc.routine @acc_routine_3 func(@_QPacc_routine4) vector
Expand Down Expand Up @@ -106,3 +108,15 @@ subroutine acc_routine14()
subroutine acc_routine15()
!$acc routine bind(acc_routine16)
end subroutine

subroutine acc_routine16()
!$acc routine device_type(host) seq dtype(nvidia) gang
end subroutine

subroutine acc_routine17()
!$acc routine device_type(host) worker dtype(multicore) vector
end subroutine

subroutine acc_routine18()
!$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16)
end subroutine
58 changes: 47 additions & 11 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1994,27 +1994,63 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {

let arguments = (ins SymbolNameAttr:$sym_name,
SymbolNameAttr:$func_name,
OptionalAttr<StrAttr>:$bind_name,
UnitAttr:$gang,
UnitAttr:$worker,
UnitAttr:$vector,
UnitAttr:$seq,
OptionalAttr<StrArrayAttr>:$bindName,
OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
OptionalAttr<DeviceTypeArrayAttr>:$worker,
OptionalAttr<DeviceTypeArrayAttr>:$vector,
OptionalAttr<DeviceTypeArrayAttr>:$seq,
UnitAttr:$nohost,
UnitAttr:$implicit,
OptionalAttr<APIntAttr>:$gangDim);
OptionalAttr<DeviceTypeArrayAttr>:$gang,
OptionalAttr<I64ArrayAttr>:$gangDim,
OptionalAttr<DeviceTypeArrayAttr>:$gangDimDeviceType);

let extraClassDeclaration = [{
static StringRef getGangDimKeyword() { return "dim"; }

/// Return true if the op has the worker attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasWorker();
/// Return true if the op has the worker attribute for the given
/// device_type.
bool hasWorker(mlir::acc::DeviceType deviceType);

/// Return true if the op has the vector attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasVector();
/// Return true if the op has the vector attribute for the given
/// device_type.
bool hasVector(mlir::acc::DeviceType deviceType);

/// Return true if the op has the seq attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasSeq();
/// Return true if the op has the seq attribute for the given
/// device_type.
bool hasSeq(mlir::acc::DeviceType deviceType);

/// Return true if the op has the gang attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasGang();
/// Return true if the op has the gang attribute for the given
/// device_type.
bool hasGang(mlir::acc::DeviceType deviceType);

std::optional<int64_t> getGangDimValue();
std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);

std::optional<llvm::StringRef> getBindNameValue();
std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
}];

let assemblyFormat = [{
$sym_name `func` `(` $func_name `)`
oilist (
`bind` `(` $bind_name `)`
| `gang` `` custom<RoutineGangClause>($gang, $gangDim)
| `worker` $worker
| `vector` $vector
| `seq` $seq
`bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
| `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
| `worker` custom<DeviceTypeArrayAttr>($worker)
| `vector` custom<DeviceTypeArrayAttr>($vector)
| `seq` custom<DeviceTypeArrayAttr>($seq)
| `nohost` $nohost
| `implicit` $implicit
) attr-dict-with-keyword
Expand Down
Loading