Skip to content

Commit 74a3b0f

Browse files
authored
Merge pull request #528 from clementval/flang/openacc/lower/op/update
[flang][openacc] Lower update directive
2 parents d055546 + 52bd3ef commit 74a3b0f

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,85 @@ genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
725725
exitDataOp.finalizeAttr(firOpBuilder.getUnitAttr());
726726
}
727727

728+
static void
729+
genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
730+
const Fortran::parser::AccClauseList &accClauseList) {
731+
mlir::Value ifCond, async, waitDevnum;
732+
SmallVector<Value, 2> hostOperands, deviceOperands, waitOperands;
733+
734+
// Async and wait clause have optional values but can be present with
735+
// no value as well. When there is no value, the op has an attribute to
736+
// represent the clause.
737+
bool addAsyncAttr = false;
738+
bool addWaitAttr = false;
739+
bool addIfPresentAttr = false;
740+
741+
auto &firOpBuilder = converter.getFirOpBuilder();
742+
auto currentLocation = converter.getCurrentLocation();
743+
744+
// Lower clauses values mapped to operands.
745+
// Keep track of each group of operands separatly as clauses can appear
746+
// more than once.
747+
for (const auto &clause : accClauseList.v) {
748+
if (const auto *asyncClause =
749+
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
750+
const auto &asyncClauseValue = asyncClause->v;
751+
if (asyncClauseValue) { // async has a value.
752+
async = fir::getBase(converter.genExprValue(
753+
*Fortran::semantics::GetExpr(*asyncClauseValue)));
754+
} else {
755+
addAsyncAttr = true;
756+
}
757+
} else if (const auto *waitClause =
758+
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
759+
const auto &waitClauseValue = waitClause->v;
760+
if (waitClauseValue) { // wait has a value.
761+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
762+
const std::list<Fortran::parser::ScalarIntExpr> &waitList =
763+
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
764+
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
765+
mlir::Value v = fir::getBase(
766+
converter.genExprValue(*Fortran::semantics::GetExpr(value)));
767+
waitOperands.push_back(v);
768+
}
769+
770+
const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
771+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
772+
if (waitDevnumValue)
773+
waitDevnum = fir::getBase(converter.genExprValue(
774+
*Fortran::semantics::GetExpr(*waitDevnumValue)));
775+
} else {
776+
addWaitAttr = true;
777+
}
778+
} else if (const auto *hostClause =
779+
std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
780+
genObjectList(hostClause->v, converter, hostOperands);
781+
} else if (const auto *deviceClause =
782+
std::get_if<Fortran::parser::AccClause::Device>(&clause.u)) {
783+
genObjectList(deviceClause->v, converter, deviceOperands);
784+
}
785+
}
786+
787+
// Prepare the operand segement size attribute and the operands value range.
788+
SmallVector<mlir::Value, 10> operands;
789+
SmallVector<int32_t, 5> operandSegments;
790+
addOperand(operands, operandSegments, async);
791+
addOperand(operands, operandSegments, waitDevnum);
792+
addOperands(operands, operandSegments, waitOperands);
793+
addOperands(operands, operandSegments, hostOperands);
794+
addOperands(operands, operandSegments, deviceOperands);
795+
796+
auto updateOp = createSimpleOp<mlir::acc::UpdateOp>(
797+
firOpBuilder, currentLocation, operands, operandSegments);
798+
799+
if (addAsyncAttr)
800+
updateOp.asyncAttr(firOpBuilder.getUnitAttr());
801+
if (addWaitAttr)
802+
updateOp.waitAttr(firOpBuilder.getUnitAttr());
803+
if (addIfPresentAttr)
804+
updateOp.ifPresentAttr(firOpBuilder.getUnitAttr());
805+
}
806+
728807
static void
729808
genACC(Fortran::lower::AbstractConverter &converter,
730809
Fortran::lower::pft::Evaluation &eval,
@@ -745,7 +824,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
745824
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
746825
TODO("OpenACC set directive not lowered yet!");
747826
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
748-
TODO("OpenACC update directive not lowered yet!");
827+
genACCUpdateOp(converter, accClauseList);
749828
}
750829
}
751830

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
! This test checks lowering of OpenACC update directive.
2+
3+
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
4+
5+
subroutine acc_update
6+
integer :: async = 1
7+
real, dimension(10, 10) :: a, b, c
8+
logical :: ifCondition = .TRUE.
9+
10+
!CHECK: [[A:%.*]] = fir.alloca !fir.array<10x10xf32> {name = "{{.*}}Ea"}
11+
!CHECK: [[B:%.*]] = fir.alloca !fir.array<10x10xf32> {name = "{{.*}}Eb"}
12+
!CHECK: [[C:%.*]] = fir.alloca !fir.array<10x10xf32> {name = "{{.*}}Ec"}
13+
14+
!$acc update host(a)
15+
!CHECK: acc.update host([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
16+
17+
!$acc update host(a) host(b) host(c)
18+
!CHECK: acc.update host([[A]], [[B]], [[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}
19+
20+
!$acc update host(a) host(b) device(c)
21+
!CHECK: acc.update host([[A]], [[B]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) device([[C]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
22+
23+
!$acc update host(a) async
24+
!CHECK: acc.update host([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
25+
26+
!$acc update host(a) wait
27+
!CHECK: acc.update host([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
28+
29+
!$acc update host(a) async wait
30+
!CHECK: acc.update host([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
31+
32+
!$acc update host(a) async(1)
33+
!CHECK: [[ASYNC1:%.*]] = constant 1 : i32
34+
!CHECK: acc.update async([[ASYNC1]] : i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
35+
36+
!$acc update host(a) async(async)
37+
!CHECK: [[ASYNC2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
38+
!CHECK: acc.update async([[ASYNC2]] : i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
39+
40+
!$acc update host(a) wait(1)
41+
!CHECK: [[WAIT1:%.*]] = constant 1 : i32
42+
!CHECK: acc.update wait([[WAIT1]] : i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
43+
44+
!$acc update host(a) wait(queues: 1, 2)
45+
!CHECK: [[WAIT2:%.*]] = constant 1 : i32
46+
!CHECK: [[WAIT3:%.*]] = constant 2 : i32
47+
!CHECK: acc.update wait([[WAIT2]], [[WAIT3]] : i32, i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
48+
49+
!$acc update host(a) wait(devnum: 1: queues: 1, 2)
50+
!CHECK: [[WAIT4:%.*]] = constant 1 : i32
51+
!CHECK: [[WAIT5:%.*]] = constant 2 : i32
52+
!CHECK: [[WAIT6:%.*]] = constant 1 : i32
53+
!CHECK: acc.update wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) host([[A]] : !fir.ref<!fir.array<10x10xf32>>)
54+
55+
end subroutine acc_update

0 commit comments

Comments
 (0)