Skip to content

Commit d055546

Browse files
authored
Merge pull request #526 from clementval/flang/openacc/lower/op/enter_data
[flang][openacc] Lower enter data directive
2 parents b1c4b28 + 4667b0b commit d055546

File tree

2 files changed

+169
-3
lines changed

2 files changed

+169
-3
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,105 @@ genACC(Fortran::lower::AbstractConverter &converter,
528528
}
529529
}
530530

531+
static void
532+
genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
533+
const Fortran::parser::AccClauseList &accClauseList) {
534+
mlir::Value ifCond, async, waitDevnum;
535+
SmallVector<Value, 2> copyinOperands, createOperands, createZeroOperands,
536+
attachOperands, waitOperands;
537+
538+
// Async, wait and self clause have optional values but can be present with
539+
// no value as well. When there is no value, the op has an attribute to
540+
// represent the clause.
541+
bool addAsyncAttr = false;
542+
bool addWaitAttr = false;
543+
544+
auto &firOpBuilder = converter.getFirOpBuilder();
545+
auto currentLocation = converter.getCurrentLocation();
546+
547+
// Lower clauses values mapped to operands.
548+
// Keep track of each group of operands separatly as clauses can appear
549+
// more than once.
550+
for (const auto &clause : accClauseList.v) {
551+
if (const auto *ifClause =
552+
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
553+
mlir::Value cond = fir::getBase(
554+
converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v)));
555+
ifCond = firOpBuilder.createConvert(currentLocation,
556+
firOpBuilder.getI1Type(), cond);
557+
} else if (const auto *asyncClause =
558+
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
559+
const auto &asyncClauseValue = asyncClause->v;
560+
if (asyncClauseValue) { // async has a value.
561+
async = fir::getBase(converter.genExprValue(
562+
*Fortran::semantics::GetExpr(*asyncClauseValue)));
563+
} else {
564+
addAsyncAttr = true;
565+
}
566+
} else if (const auto *waitClause =
567+
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
568+
const auto &waitClauseValue = waitClause->v;
569+
if (waitClauseValue) { // wait has a value.
570+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
571+
const std::list<Fortran::parser::ScalarIntExpr> &waitList =
572+
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
573+
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
574+
mlir::Value v = fir::getBase(
575+
converter.genExprValue(*Fortran::semantics::GetExpr(value)));
576+
waitOperands.push_back(v);
577+
}
578+
579+
const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
580+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
581+
if (waitDevnumValue)
582+
waitDevnum = fir::getBase(converter.genExprValue(
583+
*Fortran::semantics::GetExpr(*waitDevnumValue)));
584+
} else {
585+
addWaitAttr = true;
586+
}
587+
} else if (const auto *copyinClause =
588+
std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
589+
const Fortran::parser::AccObjectListWithModifier &listWithModifier =
590+
copyinClause->v;
591+
const Fortran::parser::AccObjectList &accObjectList =
592+
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
593+
genObjectList(accObjectList, converter, copyinOperands);
594+
} else if (const auto *createClause =
595+
std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
596+
genObjectListWithModifier<Fortran::parser::AccClause::Create>(
597+
createClause, converter,
598+
Fortran::parser::AccDataModifier::Modifier::Zero, createZeroOperands,
599+
createOperands);
600+
} else if (const auto *attachClause =
601+
std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
602+
genObjectList(attachClause->v, converter, attachOperands);
603+
} else {
604+
llvm::report_fatal_error(
605+
"Unknown clause in ENTER DATA directive lowering");
606+
}
607+
}
608+
609+
// Prepare the operand segement size attribute and the operands value range.
610+
SmallVector<mlir::Value, 16> operands;
611+
SmallVector<int32_t, 8> operandSegments;
612+
addOperand(operands, operandSegments, ifCond);
613+
addOperand(operands, operandSegments, async);
614+
addOperand(operands, operandSegments, waitDevnum);
615+
addOperands(operands, operandSegments, waitOperands);
616+
addOperands(operands, operandSegments, copyinOperands);
617+
addOperands(operands, operandSegments, createOperands);
618+
addOperands(operands, operandSegments, createZeroOperands);
619+
addOperands(operands, operandSegments, attachOperands);
620+
621+
auto enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
622+
firOpBuilder, currentLocation, operands, operandSegments);
623+
624+
if (addAsyncAttr)
625+
enterDataOp.asyncAttr(firOpBuilder.getUnitAttr());
626+
if (addWaitAttr)
627+
enterDataOp.waitAttr(firOpBuilder.getUnitAttr());
628+
}
629+
531630
static void
532631
genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
533632
const Fortran::parser::AccClauseList &accClauseList) {
@@ -605,8 +704,8 @@ genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
605704
}
606705

607706
// Prepare the operand segement size attribute and the operands value range.
608-
SmallVector<Value, 8> operands;
609-
SmallVector<int32_t, 8> operandSegments;
707+
SmallVector<mlir::Value, 14> operands;
708+
SmallVector<int32_t, 7> operandSegments;
610709
addOperand(operands, operandSegments, ifCond);
611710
addOperand(operands, operandSegments, async);
612711
addOperand(operands, operandSegments, waitDevnum);
@@ -636,7 +735,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
636735
std::get<Fortran::parser::AccClauseList>(standaloneConstruct.t);
637736

638737
if (standaloneDirective.v == llvm::acc::Directive::ACCD_enter_data) {
639-
TODO("OpenACC enter data directive not lowered yet!");
738+
genACCEnterDataOp(converter, accClauseList);
640739
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_exit_data) {
641740
genACCExitDataOp(converter, accClauseList);
642741
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_init) {
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
! This test checks lowering of OpenACC enter data directive.
2+
3+
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
4+
5+
subroutine acc_enter_data
6+
integer :: async = 1
7+
real, dimension(10, 10) :: a, b, c
8+
logical :: ifCondition = .TRUE.
9+
10+
!CHECK: [[A:%.*]] = fir.alloca !fir.array<10x10xf32> {name = "{{.*}}Ea"}
11+
!CHECK: [[B:%.*]] = fir.alloca !fir.array<10x10xf32> {name = "{{.*}}Eb"}
12+
!CHECK: [[C:%.*]] = fir.alloca !fir.array<10x10xf32> {name = "{{.*}}Ec"}
13+
14+
!$acc enter data create(a)
15+
!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
16+
17+
!$acc enter data create(a) if(.true.)
18+
!CHECK: [[IF1:%.*]] = constant true
19+
!CHECK: acc.enter_data if([[IF1]]) create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
20+
21+
!$acc enter data create(a) if(ifCondition)
22+
!CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
23+
!CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
24+
!CHECK: acc.enter_data if([[IF2]]) create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
25+
26+
!$acc enter data create(a) create(b) create(c)
27+
!CHECK: acc.enter_data create([[A]], [[B]], [[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}
28+
29+
!$acc enter data create(a) create(b) create(zero: c)
30+
!CHECK: acc.enter_data create([[A]], [[B]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) create_zero([[C]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
31+
32+
!$acc enter data copyin(a) create(b) attach(c)
33+
!CHECK: acc.enter_data copyin([[A]] : !fir.ref<!fir.array<10x10xf32>>) create([[B]] : !fir.ref<!fir.array<10x10xf32>>) attach([[C]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
34+
35+
!$acc enter data create(a) async
36+
!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
37+
38+
!$acc enter data create(a) wait
39+
!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
40+
41+
!$acc enter data create(a) async wait
42+
!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
43+
44+
!$acc enter data create(a) async(1)
45+
!CHECK: [[ASYNC1:%.*]] = constant 1 : i32
46+
!CHECK: acc.enter_data async([[ASYNC1]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
47+
48+
!$acc enter data create(a) async(async)
49+
!CHECK: [[ASYNC2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
50+
!CHECK: acc.enter_data async([[ASYNC2]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
51+
52+
!$acc enter data create(a) wait(1)
53+
!CHECK: [[WAIT1:%.*]] = constant 1 : i32
54+
!CHECK: acc.enter_data wait([[WAIT1]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
55+
56+
!$acc enter data create(a) wait(queues: 1, 2)
57+
!CHECK: [[WAIT2:%.*]] = constant 1 : i32
58+
!CHECK: [[WAIT3:%.*]] = constant 2 : i32
59+
!CHECK: acc.enter_data wait([[WAIT2]], [[WAIT3]] : i32, i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
60+
61+
!$acc enter data create(a) wait(devnum: 1: queues: 1, 2)
62+
!CHECK: [[WAIT4:%.*]] = constant 1 : i32
63+
!CHECK: [[WAIT5:%.*]] = constant 2 : i32
64+
!CHECK: [[WAIT6:%.*]] = constant 1 : i32
65+
!CHECK: acc.enter_data wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
66+
67+
end subroutine acc_enter_data

0 commit comments

Comments
 (0)