Skip to content

[mlir][openacc] Add device_type support for compute operations #75864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 84 additions & 22 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
return mlir::acc::DeviceType::Multicore;
}
return mlir::acc::DeviceType::Default;
return mlir::acc::DeviceType::None;
}

static void gatherDeviceTypeAttrs(
Expand Down Expand Up @@ -1781,61 +1781,90 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
bool outerCombined = false) {

// Parallel operation operands
mlir::Value async;
mlir::Value numWorkers;
mlir::Value vectorLength;
mlir::Value ifCond;
mlir::Value selfCond;
mlir::Value waitDevnum;
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
dataClauseOperands, numGangs;
dataClauseOperands, numGangs, numWorkers, vectorLength, async;
llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;

llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;

// Async, wait and self clause have optional values but can be present with
// Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
// represent the clause.
bool addAsyncAttr = false;
bool addWaitAttr = false;
bool addSelfAttr = false;

bool hasDefaultNone = false;
bool hasDefaultPresent = false;

fir::FirOpBuilder &builder = converter.getFirOpBuilder();

// device_type attribute is set to `none` until a device_type clause is
// encountered.
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
builder.getContext(), mlir::acc::DeviceType::None);

// Lower clauses values mapped to operands.
// Keep track of each group of operands separately as clauses can appear
// more than once.
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
if (const auto *asyncClause =
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
const auto &asyncClauseValue = asyncClause->v;
if (asyncClauseValue) { // async has a value.
async.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
asyncDeviceTypes.push_back(crtDeviceTypeAttr);
} else {
asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
addWaitAttr, stmtCtx);
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
auto crtWaitOperands = waitOperands.size();
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
waitOperands.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}
waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
} else {
waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
auto crtNumGangs = numGangs.size();
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
numGangs.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(expr), stmtCtx)));
numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
numGangsSegments.push_back(numGangs.size() - crtNumGangs);
} else if (const auto *numWorkersClause =
std::get_if<Fortran::parser::AccClause::NumWorkers>(
&clause.u)) {
numWorkers = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
numWorkers.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto *vectorLengthClause =
std::get_if<Fortran::parser::AccClause::VectorLength>(
&clause.u)) {
vectorLength = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
vectorLength.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto *ifClause =
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
Expand Down Expand Up @@ -1986,18 +2015,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
else if ((defaultClause->v).v ==
llvm::acc::DefaultValue::ACC_Default_present)
hasDefaultPresent = true;
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
deviceTypeClause->v;
assert(deviceTypeExprList.v.size() == 1 &&
"expect only one device_type expr");
crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
}
}

// Prepare the operand segment size attribute and the operands value range.
llvm::SmallVector<mlir::Value, 8> operands;
llvm::SmallVector<int32_t, 8> operandSegments;
addOperand(operands, operandSegments, async);
addOperands(operands, operandSegments, async);
addOperands(operands, operandSegments, waitOperands);
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
addOperands(operands, operandSegments, numGangs);
addOperand(operands, operandSegments, numWorkers);
addOperand(operands, operandSegments, vectorLength);
addOperands(operands, operandSegments, numWorkers);
addOperands(operands, operandSegments, vectorLength);
}
addOperand(operands, operandSegments, ifCond);
addOperand(operands, operandSegments, selfCond);
Expand All @@ -2018,10 +2056,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
builder, currentLocation, eval, operands, operandSegments,
outerCombined);

if (addAsyncAttr)
computeOp.setAsyncAttrAttr(builder.getUnitAttr());
if (addWaitAttr)
computeOp.setWaitAttrAttr(builder.getUnitAttr());
if (addSelfAttr)
computeOp.setSelfAttrAttr(builder.getUnitAttr());

Expand All @@ -2030,6 +2064,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (hasDefaultPresent)
computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);

if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
if (!numWorkersDeviceTypes.empty())
computeOp.setNumWorkersDeviceTypeAttr(
mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
if (!vectorLengthDeviceTypes.empty())
computeOp.setVectorLengthDeviceTypeAttr(
mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
if (!numGangsDeviceTypes.empty())
computeOp.setNumGangsDeviceTypeAttr(
mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
if (!numGangsSegments.empty())
computeOp.setNumGangsSegmentsAttr(
builder.getDenseI32ArrayAttr(numGangsSegments));
}
if (!asyncDeviceTypes.empty())
computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));

if (!waitOperandsDeviceTypes.empty())
computeOp.setWaitOperandsDeviceTypeAttr(
builder.getArrayAttr(waitOperandsDeviceTypes));
if (!waitOperandsSegments.empty())
computeOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
if (!waitOnlyDeviceTypes.empty())
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));

if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
if (!privatizations.empty())
computeOp.setPrivatizationsAttr(
Expand Down
44 changes: 44 additions & 0 deletions flang/test/Lower/OpenACC/acc-device-type.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
! This test checks lowering of the OpenACC device_type clause on directives where
! its position and the clauses that follow it have special semantics.

! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s

! sub1 exercises num_workers, vector_length and num_gangs on `acc parallel`,
! with and without a preceding device_type clause. Each construct is followed
! by the acc.parallel output that FileCheck expects for it.
subroutine sub1()

! Case 1: num_workers with no device_type clause. The expected operand is
! printed without any device_type annotation.
!$acc parallel num_workers(16)
!$acc end parallel

! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) {

! Case 2: num_workers appears before and after device_type(nvidia). The first
! value is expected unannotated; the second is expected to be tagged with
! #acc.device_type<nvidia>.
!$acc parallel num_workers(1) device_type(nvidia) num_workers(16)
!$acc end parallel

! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type<nvidia>])

! Case 3: every num_workers follows a device_type clause. device_type(*) is
! expected to print as #acc.device_type<star>.
!$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16)
!$acc end parallel

! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type<star>], %c16{{.*}} : i32 [#acc.device_type<nvidia>])

! Case 4: vector_length with no device_type clause; expected unannotated.
!$acc parallel vector_length(1)
!$acc end parallel

! CHECK: acc.parallel vector_length(%c1{{.*}} : i32)

! Case 5: vector_length after device_type(multicore); expected to be tagged
! with #acc.device_type<multicore>.
!$acc parallel device_type(multicore) vector_length(1)
!$acc end parallel

! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type<multicore>])

! Case 6: num_gangs before and after device_type(nvidia). Each num_gangs
! clause is expected as its own brace-delimited group; only the second group
! carries the nvidia annotation.
!$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4)
!$acc end parallel

! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type<nvidia>])

! Case 7: same as case 6, but the device-specific num_gangs has three
! arguments, which are expected to stay together in a single group.
!$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1)
!$acc end parallel

! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])


end subroutine
14 changes: 7 additions & 7 deletions flang/test/Lower/OpenACC/acc-kernels-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ subroutine acc_kernels_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {asyncAttr}
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}

!$acc kernels loop async(1)
DO i = 1, n
Expand Down Expand Up @@ -103,15 +103,15 @@ subroutine acc_kernels_loop
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {waitAttr}
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}

!$acc kernels loop wait(1)
DO i = 1, n
a(i) = b(i)
END DO

! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand All @@ -126,7 +126,7 @@ subroutine acc_kernels_loop

! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand All @@ -141,7 +141,7 @@ subroutine acc_kernels_loop

! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand All @@ -155,7 +155,7 @@ subroutine acc_kernels_loop
END DO

! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand All @@ -169,7 +169,7 @@ subroutine acc_kernels_loop
END DO

! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
! CHECK: acc.loop {
! CHECK: fir.do_loop
! CHECK: acc.yield
Expand Down
14 changes: 7 additions & 7 deletions flang/test/Lower/OpenACC/acc-kernels.f90
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ subroutine acc_kernels

! CHECK: acc.kernels {
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {asyncAttr}
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}

!$acc kernels async(1)
!$acc end kernels
Expand All @@ -63,13 +63,13 @@ subroutine acc_kernels

! CHECK: acc.kernels {
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {waitAttr}
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}

!$acc kernels wait(1)
!$acc end kernels

! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

Expand All @@ -78,7 +78,7 @@ subroutine acc_kernels

! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

Expand All @@ -87,23 +87,23 @@ subroutine acc_kernels

! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

!$acc kernels num_gangs(1)
!$acc end kernels

! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

!$acc kernels num_gangs(numGangs)
!$acc end kernels

! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
! CHECK: acc.terminator
! CHECK-NEXT: }{{$}}

Expand Down
Loading