Skip to content

Commit a25da1a

Browse files
committed
[mlir][openacc] Add device_type support for compute operations (#75864)
Re-land PR after being reverted because of buildbot failures. This patch adds representation for `device_type` clause information on compute construct (parallel, kernels, serial). The `device_type` clause on compute construct impacts clauses that appear after it. The values impacted by `device_type` are now tied with an attribute array that represent the device_type associated with them. `DeviceType::None` is used to represent the value produced by a clause before any `device_type`. The operands and the attribute information are parser/printed together. This is an example with `vector_length` clause. The first value (64) is not impacted by `device_type` so it will be represented with DeviceType::None. None is not printed. The second value (128) is tied with the `device_type(multicore)` clause. ``` !$acc parallel vector_length(64) device_type(multicore) vector_length(256) ``` ``` acc.parallel vector_length(%c64 : i32, %c128 : i32 [#acc.device_type<multicore>]) { } ``` When multiple values can be produced for a single clause like `num_gangs` and `wait`, an extra attribute describe the number of values belonging to each `device_type`. Values and attributes are parsed/printed together. ``` acc.parallel num_gangs({%c2 : i32, %c4 : i32}, {%c4 : i32} [#acc.device_type<nvidia>]) ``` While preparing this patch I noticed that the wait devnum is not part of the operations and is not lowered. It will be added in a follow up patch.
1 parent b26c0ed commit a25da1a

File tree

12 files changed

+932
-178
lines changed

12 files changed

+932
-178
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 84 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
14801480
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
14811481
return mlir::acc::DeviceType::Multicore;
14821482
}
1483-
return mlir::acc::DeviceType::Default;
1483+
return mlir::acc::DeviceType::None;
14841484
}
14851485

14861486
static void gatherDeviceTypeAttrs(
@@ -1781,61 +1781,89 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
17811781
bool outerCombined = false) {
17821782

17831783
// Parallel operation operands
1784-
mlir::Value async;
1785-
mlir::Value numWorkers;
1786-
mlir::Value vectorLength;
17871784
mlir::Value ifCond;
17881785
mlir::Value selfCond;
1789-
mlir::Value waitDevnum;
17901786
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
17911787
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
1792-
dataClauseOperands, numGangs;
1788+
dataClauseOperands, numGangs, numWorkers, vectorLength, async;
1789+
llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
1790+
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
1791+
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
1792+
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
17931793

17941794
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
17951795
firstprivateOperands;
17961796
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
17971797
reductionRecipes;
17981798

1799-
// Async, wait and self clause have optional values but can be present with
1799+
// Self clause has optional values but can be present with
18001800
// no value as well. When there is no value, the op has an attribute to
18011801
// represent the clause.
1802-
bool addAsyncAttr = false;
1803-
bool addWaitAttr = false;
18041802
bool addSelfAttr = false;
18051803

18061804
bool hasDefaultNone = false;
18071805
bool hasDefaultPresent = false;
18081806

18091807
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
18101808

1809+
// device_type attribute is set to `none` until a device_type clause is
1810+
// encountered.
1811+
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
1812+
builder.getContext(), mlir::acc::DeviceType::None);
1813+
18111814
// Lower clauses values mapped to operands.
18121815
// Keep track of each group of operands separatly as clauses can appear
18131816
// more than once.
18141817
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
18151818
mlir::Location clauseLocation = converter.genLocation(clause.source);
18161819
if (const auto *asyncClause =
18171820
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
1818-
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
1821+
const auto &asyncClauseValue = asyncClause->v;
1822+
if (asyncClauseValue) { // async has a value.
1823+
async.push_back(fir::getBase(converter.genExprValue(
1824+
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
1825+
asyncDeviceTypes.push_back(crtDeviceTypeAttr);
1826+
} else {
1827+
asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1828+
}
18191829
} else if (const auto *waitClause =
18201830
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
1821-
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
1822-
addWaitAttr, stmtCtx);
1831+
const auto &waitClauseValue = waitClause->v;
1832+
if (waitClauseValue) { // wait has a value.
1833+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1834+
const auto &waitList =
1835+
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1836+
auto crtWaitOperands = waitOperands.size();
1837+
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1838+
waitOperands.push_back(fir::getBase(converter.genExprValue(
1839+
*Fortran::semantics::GetExpr(value), stmtCtx)));
1840+
}
1841+
waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1842+
waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
1843+
} else {
1844+
waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1845+
}
18231846
} else if (const auto *numGangsClause =
18241847
std::get_if<Fortran::parser::AccClause::NumGangs>(
18251848
&clause.u)) {
1849+
auto crtNumGangs = numGangs.size();
18261850
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
18271851
numGangs.push_back(fir::getBase(converter.genExprValue(
18281852
*Fortran::semantics::GetExpr(expr), stmtCtx)));
1853+
numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
1854+
numGangsSegments.push_back(numGangs.size() - crtNumGangs);
18291855
} else if (const auto *numWorkersClause =
18301856
std::get_if<Fortran::parser::AccClause::NumWorkers>(
18311857
&clause.u)) {
1832-
numWorkers = fir::getBase(converter.genExprValue(
1833-
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
1858+
numWorkers.push_back(fir::getBase(converter.genExprValue(
1859+
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
1860+
numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
18341861
} else if (const auto *vectorLengthClause =
18351862
std::get_if<Fortran::parser::AccClause::VectorLength>(
18361863
&clause.u)) {
1837-
vectorLength = fir::getBase(converter.genExprValue(
1838-
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
1864+
vectorLength.push_back(fir::getBase(converter.genExprValue(
1865+
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
1866+
vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
18391867
} else if (const auto *ifClause =
18401868
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
18411869
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -1986,18 +2014,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
19862014
else if ((defaultClause->v).v ==
19872015
llvm::acc::DefaultValue::ACC_Default_present)
19882016
hasDefaultPresent = true;
2017+
} else if (const auto *deviceTypeClause =
2018+
std::get_if<Fortran::parser::AccClause::DeviceType>(
2019+
&clause.u)) {
2020+
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
2021+
deviceTypeClause->v;
2022+
assert(deviceTypeExprList.v.size() == 1 &&
2023+
"expect only one device_type expr");
2024+
crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2025+
builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
19892026
}
19902027
}
19912028

19922029
// Prepare the operand segment size attribute and the operands value range.
19932030
llvm::SmallVector<mlir::Value, 8> operands;
19942031
llvm::SmallVector<int32_t, 8> operandSegments;
1995-
addOperand(operands, operandSegments, async);
2032+
addOperands(operands, operandSegments, async);
19962033
addOperands(operands, operandSegments, waitOperands);
19972034
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
19982035
addOperands(operands, operandSegments, numGangs);
1999-
addOperand(operands, operandSegments, numWorkers);
2000-
addOperand(operands, operandSegments, vectorLength);
2036+
addOperands(operands, operandSegments, numWorkers);
2037+
addOperands(operands, operandSegments, vectorLength);
20012038
}
20022039
addOperand(operands, operandSegments, ifCond);
20032040
addOperand(operands, operandSegments, selfCond);
@@ -2018,10 +2055,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20182055
builder, currentLocation, eval, operands, operandSegments,
20192056
outerCombined);
20202057

2021-
if (addAsyncAttr)
2022-
computeOp.setAsyncAttrAttr(builder.getUnitAttr());
2023-
if (addWaitAttr)
2024-
computeOp.setWaitAttrAttr(builder.getUnitAttr());
20252058
if (addSelfAttr)
20262059
computeOp.setSelfAttrAttr(builder.getUnitAttr());
20272060

@@ -2030,6 +2063,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20302063
if (hasDefaultPresent)
20312064
computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
20322065

2066+
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2067+
if (!numWorkersDeviceTypes.empty())
2068+
computeOp.setNumWorkersDeviceTypeAttr(
2069+
mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
2070+
if (!vectorLengthDeviceTypes.empty())
2071+
computeOp.setVectorLengthDeviceTypeAttr(
2072+
mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
2073+
if (!numGangsDeviceTypes.empty())
2074+
computeOp.setNumGangsDeviceTypeAttr(
2075+
mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
2076+
if (!numGangsSegments.empty())
2077+
computeOp.setNumGangsSegmentsAttr(
2078+
builder.getDenseI32ArrayAttr(numGangsSegments));
2079+
}
2080+
if (!asyncDeviceTypes.empty())
2081+
computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
2082+
if (!asyncOnlyDeviceTypes.empty())
2083+
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2084+
2085+
if (!waitOperandsDeviceTypes.empty())
2086+
computeOp.setWaitOperandsDeviceTypeAttr(
2087+
builder.getArrayAttr(waitOperandsDeviceTypes));
2088+
if (!waitOperandsSegments.empty())
2089+
computeOp.setWaitOperandsSegmentsAttr(
2090+
builder.getDenseI32ArrayAttr(waitOperandsSegments));
2091+
if (!waitOnlyDeviceTypes.empty())
2092+
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2093+
20332094
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
20342095
if (!privatizations.empty())
20352096
computeOp.setPrivatizationsAttr(
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
! This test checks lowering of OpenACC device_type clause on directive where its
2+
! position and the clauses that follow have special semantic
3+
4+
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
5+
6+
subroutine sub1()
7+
8+
!$acc parallel num_workers(16)
9+
!$acc end parallel
10+
11+
! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) {
12+
13+
!$acc parallel num_workers(1) device_type(nvidia) num_workers(16)
14+
!$acc end parallel
15+
16+
! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type<nvidia>])
17+
18+
!$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16)
19+
!$acc end parallel
20+
21+
! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type<star>], %c16{{.*}} : i32 [#acc.device_type<nvidia>])
22+
23+
!$acc parallel vector_length(1)
24+
!$acc end parallel
25+
26+
! CHECK: acc.parallel vector_length(%c1{{.*}} : i32)
27+
28+
!$acc parallel device_type(multicore) vector_length(1)
29+
!$acc end parallel
30+
31+
! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type<multicore>])
32+
33+
!$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4)
34+
!$acc end parallel
35+
36+
! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type<nvidia>])
37+
38+
!$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1)
39+
!$acc end parallel
40+
41+
! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
42+
43+
44+
end subroutine

flang/test/Lower/OpenACC/acc-kernels-loop.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
6262
! CHECK: acc.yield
6363
! CHECK-NEXT: }{{$}}
6464
! CHECK: acc.terminator
65-
! CHECK-NEXT: } attributes {asyncAttr}
65+
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
6666

6767
!$acc kernels loop async(1)
6868
DO i = 1, n
@@ -103,15 +103,15 @@ subroutine acc_kernels_loop
103103
! CHECK: acc.yield
104104
! CHECK-NEXT: }{{$}}
105105
! CHECK: acc.terminator
106-
! CHECK-NEXT: } attributes {waitAttr}
106+
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
107107

108108
!$acc kernels loop wait(1)
109109
DO i = 1, n
110110
a(i) = b(i)
111111
END DO
112112

113113
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
114-
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
114+
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
115115
! CHECK: acc.loop {
116116
! CHECK: fir.do_loop
117117
! CHECK: acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
126126

127127
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
128128
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
129-
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
129+
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
130130
! CHECK: acc.loop {
131131
! CHECK: fir.do_loop
132132
! CHECK: acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
141141

142142
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
143143
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
144-
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
144+
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
145145
! CHECK: acc.loop {
146146
! CHECK: fir.do_loop
147147
! CHECK: acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
155155
END DO
156156

157157
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
158-
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
158+
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
159159
! CHECK: acc.loop {
160160
! CHECK: fir.do_loop
161161
! CHECK: acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
169169
END DO
170170

171171
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
172-
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
172+
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
173173
! CHECK: acc.loop {
174174
! CHECK: fir.do_loop
175175
! CHECK: acc.yield

flang/test/Lower/OpenACC/acc-kernels.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ subroutine acc_kernels
4040

4141
! CHECK: acc.kernels {
4242
! CHECK: acc.terminator
43-
! CHECK-NEXT: } attributes {asyncAttr}
43+
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
4444

4545
!$acc kernels async(1)
4646
!$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
6363

6464
! CHECK: acc.kernels {
6565
! CHECK: acc.terminator
66-
! CHECK-NEXT: } attributes {waitAttr}
66+
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
6767

6868
!$acc kernels wait(1)
6969
!$acc end kernels
7070

7171
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
72-
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
72+
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
7373
! CHECK: acc.terminator
7474
! CHECK-NEXT: }{{$}}
7575

@@ -78,7 +78,7 @@ subroutine acc_kernels
7878

7979
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
8080
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
81-
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
81+
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
8282
! CHECK: acc.terminator
8383
! CHECK-NEXT: }{{$}}
8484

@@ -87,23 +87,23 @@ subroutine acc_kernels
8787

8888
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
8989
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
90-
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
90+
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
9191
! CHECK: acc.terminator
9292
! CHECK-NEXT: }{{$}}
9393

9494
!$acc kernels num_gangs(1)
9595
!$acc end kernels
9696

9797
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
98-
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
98+
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
9999
! CHECK: acc.terminator
100100
! CHECK-NEXT: }{{$}}
101101

102102
!$acc kernels num_gangs(numGangs)
103103
!$acc end kernels
104104

105105
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
106-
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
106+
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
107107
! CHECK: acc.terminator
108108
! CHECK-NEXT: }{{$}}
109109

0 commit comments

Comments
 (0)