Skip to content

[mlir][openacc][flang] Support wait devnum and clean async/wait IR #79525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 57 additions & 36 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, descTy);
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);

Expand Down Expand Up @@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
builder, loc, loadOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, loadOp.getType());
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
Expand Down Expand Up @@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
}
}

static void
genWaitClause(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClause::Wait *waitClause,
llvm::SmallVector<mlir::Value> &waitOperands,
llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
llvm::SmallVector<int32_t> &waitOperandsSegments,
mlir::Value &waitDevnum,
llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
Fortran::lower::StatementContext &stmtCtx) {
static void genWaitClauseWithDeviceType(
Fortran::lower::AbstractConverter &converter,
const Fortran::parser::AccClause::Wait *waitClause,
llvm::SmallVector<mlir::Value> &waitOperands,
llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
llvm::SmallVector<bool> &hasDevnums,
llvm::SmallVector<int32_t> &waitOperandsSegments,
llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
Fortran::lower::StatementContext &stmtCtx) {
const auto &waitClauseValue = waitClause->v;
if (waitClauseValue) { // wait has a value.
llvm::SmallVector<mlir::Value> waitValues;

const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
const auto &waitDevnumValue =
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
bool hasDevnum = false;
if (waitDevnumValue) {
waitValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)));
hasDevnum = true;
}

const auto &waitList =
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
llvm::SmallVector<mlir::Value> waitValues;
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
waitValues.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(value), stmtCtx)));
}

for (auto deviceTypeAttr : deviceTypeAttrs) {
for (auto value : waitValues)
waitOperands.push_back(value);
waitOperandsDeviceTypes.push_back(deviceTypeAttr);
waitOperandsSegments.push_back(waitValues.size());
hasDevnums.push_back(hasDevnum);
}

// TODO: move to device_type model.
const auto &waitDevnumValue =
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
if (waitDevnumValue)
waitDevnum = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
} else {
for (auto deviceTypeAttr : deviceTypeAttrs)
waitOnlyDeviceTypes.push_back(deviceTypeAttr);
Expand Down Expand Up @@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
llvm::SmallVector<bool> hasWaitDevnums;

llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
firstprivateOperands;
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
reductionRecipes;
mlir::Value waitDevnum; // TODO not yet implemented on compute op.

// Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
Expand Down Expand Up @@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
hasWaitDevnums, waitOperandsSegments,
crtDeviceTypes, stmtCtx);
} else if (const auto *numGangsClause =
std::get_if<Fortran::parser::AccClause::NumGangs>(
&clause.u)) {
Expand Down Expand Up @@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
builder.getDenseI32ArrayAttr(numGangsSegments));
}
if (!asyncDeviceTypes.empty())
computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
computeOp.setAsyncOperandsDeviceTypeAttr(
builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));

Expand All @@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
computeOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
if (!hasWaitDevnums.empty())
computeOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));

Expand Down Expand Up @@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<int32_t> waitOperandsSegments;
llvm::SmallVector<bool> hasWaitDevnums;

bool hasDefaultNone = false;
bool hasDefaultPresent = false;
Expand Down Expand Up @@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
hasWaitDevnums, waitOperandsSegments,
crtDeviceTypes, stmtCtx);
} else if(const auto *defaultClause =
std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
Expand All @@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<int32_t> operandSegments;
addOperand(operands, operandSegments, ifCond);
addOperands(operands, operandSegments, async);
addOperand(operands, operandSegments, waitDevnum);
addOperands(operands, operandSegments, waitOperands);
addOperands(operands, operandSegments, dataClauseOperands);

Expand All @@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
operandSegments);

if (!asyncDeviceTypes.empty())
dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
dataOp.setAsyncOperandsDeviceTypeAttr(
builder.getArrayAttr(asyncDeviceTypes));
if (!asyncOnlyDeviceTypes.empty())
dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
if (!waitOperandsDeviceTypes.empty())
Expand All @@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
if (!waitOperandsSegments.empty())
dataOp.setWaitOperandsSegmentsAttr(
builder.getDenseI32ArrayAttr(waitOperandsSegments));
if (!hasWaitDevnums.empty())
dataOp.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasWaitDevnums));
if (!waitOnlyDeviceTypes.empty())
dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));

Expand Down Expand Up @@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
return attributes.empty() ? nullptr : b.getArrayAttr(attributes);
}

static inline mlir::ArrayAttr
getBoolArrayAttr(fir::FirOpBuilder &b, llvm::SmallVector<bool> &values) {
return values.empty() ? nullptr : b.getBoolArrayAttr(values);
}

static inline mlir::DenseI32ArrayAttr
getDenseI32ArrayAttr(fir::FirOpBuilder &builder,
llvm::SmallVector<int32_t> &values) {
Expand All @@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
waitOperands, deviceTypeOperands, asyncOperands;
llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
llvm::SmallVector<bool> hasWaitDevnums;
llvm::SmallVector<int32_t> waitOperandsSegments;

fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Expand Down Expand Up @@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
crtDeviceTypes, stmtCtx);
} else if (const auto *waitClause =
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
genWaitClause(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
genWaitClauseWithDeviceType(converter, waitClause, waitOperands,
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
hasWaitDevnums, waitOperandsSegments,
crtDeviceTypes, stmtCtx);
} else if (const auto *deviceTypeClause =
std::get_if<Fortran::parser::AccClause::DeviceType>(
&clause.u)) {
Expand Down Expand Up @@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
builder.create<mlir::acc::UpdateOp>(
currentLocation, ifCond, asyncOperands,
getArrayAttr(builder, asyncOperandsDeviceTypes),
getArrayAttr(builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
getArrayAttr(builder, asyncOnlyDeviceTypes), waitOperands,
getDenseI32ArrayAttr(builder, waitOperandsSegments),
getArrayAttr(builder, waitOperandsDeviceTypes),
getBoolArrayAttr(builder, hasWaitDevnums),
getArrayAttr(builder, waitOnlyDeviceTypes), dataClauseOperands,
ifPresent);

Expand Down Expand Up @@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortranDesc, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);

Expand Down Expand Up @@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
builder, loc, addrOp, asFortran, bounds,
/*structured=*/false, /*implicit=*/true,
mlir::acc::DataClause::acc_update_device, addrOp.getType());
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
modBuilder.setInsertionPointAfter(postDeallocOp);
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Lower/OpenACC/acc-data.f90
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ subroutine acc_data
!$acc data present(a) wait
!$acc end data

! CHECK: acc.data dataOperands(%{{.*}}) {
! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK: acc.data dataOperands(%{{.*}}) wait {
! CHECK: }

!$acc data present(a) wait(1)
!$acc end data
Expand All @@ -176,7 +176,7 @@ subroutine acc_data
!$acc data present(a) wait(devnum: 0: 1)
!$acc end data

! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
! CHECK: acc.data dataOperands(%{{.*}}) wait({devnum: %{{.*}} : i32, %{{.*}} : i32}) {
! CHECK: }{{$}}

!$acc data default(none)
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-kernels-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@ subroutine acc_kernels_loop
a(i) = b(i)
END DO

! CHECK: acc.kernels {
! CHECK: acc.kernels wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc kernels loop wait(1)
DO i = 1, n
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-kernels.f90
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ subroutine acc_kernels
!$acc kernels wait
!$acc end kernels

! CHECK: acc.kernels {
! CHECK: acc.kernels wait {
! CHECK: acc.terminator
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc kernels wait(1)
!$acc end kernels
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-parallel-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ subroutine acc_parallel_loop
a(i) = b(i)
END DO

! CHECK: acc.parallel {
! CHECK: acc.parallel wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc parallel loop wait(1)
DO i = 1, n
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-parallel.f90
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ subroutine acc_parallel
!$acc parallel wait
!$acc end parallel

! CHECK: acc.parallel {
! CHECK: acc.parallel wait {
! CHECK: acc.yield
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc parallel wait(1)
!$acc end parallel
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-serial-loop.f90
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ subroutine acc_serial_loop
a(i) = b(i)
END DO

! CHECK: acc.serial {
! CHECK: acc.serial wait {
! CHECK: acc.loop {{.*}} {
! CHECK: acc.yield
! CHECK-NEXT: }{{$}}
! CHECK: acc.yield
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc serial loop wait(1)
DO i = 1, n
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenACC/acc-serial.f90
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ subroutine acc_serial
!$acc serial wait
!$acc end serial

! CHECK: acc.serial {
! CHECK: acc.serial wait {
! CHECK: acc.yield
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
! CHECK-NEXT: }

!$acc serial wait(1)
!$acc end serial
Expand Down
5 changes: 1 addition & 4 deletions flang/test/Lower/OpenACC/acc-update.f90
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,7 @@ subroutine acc_update

!$acc update host(a) wait(devnum: 1: queues: 1, 2)
! CHECK: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
! CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
! CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
! CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
! CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update wait({devnum: %c1{{.*}} : i32, %c1{{.*}} : i32, %c2{{.*}} : i32}) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>)
! CHECK: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}

!$acc update host(a) device_type(host, nvidia) async
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/Frontend/Directive/DirectiveBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class Directive<string d> {
// allowedClauses and requiredClauses lists.

// List of allowed clauses for the directive.
list<VersionedClause> allowedClauses = [];
list<VersionedClause> allowedClauses = [];

// List of clauses that are allowed to appear only once.
list<VersionedClause> allowedOnceClauses = [];
Expand Down
Loading