Skip to content

Commit 8b885eb

Browse files
authored
[mlir][openacc] Add device_type support for compute operations (#75864)
This patch adds representation for `device_type` clause information on compute construct (parallel, kernels, serial). The `device_type` clause on compute construct impacts clauses that appear after it. The values impacted by `device_type` are now tied with an attribute array that represent the device_type associated with them. `DeviceType::None` is used to represent the value produced by a clause before any `device_type`. The operands and the attribute information are parser/printed together. This is an example with `vector_length` clause. The first value (64) is not impacted by `device_type` so it will be represented with DeviceType::None. None is not printed. The second value (128) is tied with the `device_type(multicore)` clause. ``` !$acc parallel vector_length(64) device_type(multicore) vector_length(256) ``` ``` acc.parallel vector_length(%c64 : i32, %c128 : i32 [#acc.device_type<multicore>]) { } ``` When multiple values can be produced for a single clause like `num_gangs` and `wait`, an extra attribute describe the number of values belonging to each `device_type`. Values and attributes are parsed/printed together. ``` acc.parallel num_gangs({%c2 : i32, %c4 : i32}, {%c4 : i32} [#acc.device_type<nvidia>]) ``` While preparing this patch I noticed that the wait devnum is not part of the operations and is not lowered. It will be added in a follow up patch.
1 parent 7ffad37 commit 8b885eb

File tree

15 files changed

+1216
-177
lines changed

15 files changed

+1216
-177
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 84 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
14801480
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
14811481
return mlir::acc::DeviceType::Multicore;
14821482
}
1483-
return mlir::acc::DeviceType::Default;
1483+
return mlir::acc::DeviceType::None;
14841484
}
14851485

14861486
static void gatherDeviceTypeAttrs(
@@ -1781,61 +1781,90 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
17811781
bool outerCombined = false) {
17821782

17831783
// Parallel operation operands
1784-
mlir::Value async;
1785-
mlir::Value numWorkers;
1786-
mlir::Value vectorLength;
17871784
mlir::Value ifCond;
17881785
mlir::Value selfCond;
17891786
mlir::Value waitDevnum;
17901787
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
17911788
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
1792-
dataClauseOperands, numGangs;
1789+
dataClauseOperands, numGangs, numWorkers, vectorLength, async;
1790+
llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
1791+
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
1792+
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
1793+
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
17931794

17941795
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
17951796
firstprivateOperands;
17961797
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
17971798
reductionRecipes;
17981799

1799-
// Async, wait and self clause have optional values but can be present with
1800+
// Self clause has optional values but can be present with
18001801
// no value as well. When there is no value, the op has an attribute to
18011802
// represent the clause.
1802-
bool addAsyncAttr = false;
1803-
bool addWaitAttr = false;
18041803
bool addSelfAttr = false;
18051804

18061805
bool hasDefaultNone = false;
18071806
bool hasDefaultPresent = false;
18081807

18091808
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
18101809

1810+
// device_type attribute is set to `none` until a device_type clause is
1811+
// encountered.
1812+
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
1813+
builder.getContext(), mlir::acc::DeviceType::None);
1814+
18111815
// Lower clauses values mapped to operands.
18121816
// Keep track of each group of operands separatly as clauses can appear
18131817
// more than once.
18141818
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
18151819
mlir::Location clauseLocation = converter.genLocation(clause.source);
18161820
if (const auto *asyncClause =
18171821
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
1818-
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
1822+
const auto &asyncClauseValue = asyncClause->v;
1823+
if (asyncClauseValue) { // async has a value.
1824+
async.push_back(fir::getBase(converter.genExprValue(
1825+
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
1826+
asyncDeviceTypes.push_back(crtDeviceTypeAttr);
1827+
} else {
1828+
asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1829+
}
18191830
} else if (const auto *waitClause =
18201831
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
1821-
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
1822-
addWaitAttr, stmtCtx);
1832+
const auto &waitClauseValue = waitClause->v;
1833+
if (waitClauseValue) { // wait has a value.
1834+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1835+
const auto &waitList =
1836+
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1837+
auto crtWaitOperands = waitOperands.size();
1838+
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1839+
waitOperands.push_back(fir::getBase(converter.genExprValue(
1840+
*Fortran::semantics::GetExpr(value), stmtCtx)));
1841+
}
1842+
waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1843+
waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
1844+
} else {
1845+
waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1846+
}
18231847
} else if (const auto *numGangsClause =
18241848
std::get_if<Fortran::parser::AccClause::NumGangs>(
18251849
&clause.u)) {
1850+
auto crtNumGangs = numGangs.size();
18261851
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
18271852
numGangs.push_back(fir::getBase(converter.genExprValue(
18281853
*Fortran::semantics::GetExpr(expr), stmtCtx)));
1854+
numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
1855+
numGangsSegments.push_back(numGangs.size() - crtNumGangs);
18291856
} else if (const auto *numWorkersClause =
18301857
std::get_if<Fortran::parser::AccClause::NumWorkers>(
18311858
&clause.u)) {
1832-
numWorkers = fir::getBase(converter.genExprValue(
1833-
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
1859+
numWorkers.push_back(fir::getBase(converter.genExprValue(
1860+
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
1861+
numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
18341862
} else if (const auto *vectorLengthClause =
18351863
std::get_if<Fortran::parser::AccClause::VectorLength>(
18361864
&clause.u)) {
1837-
vectorLength = fir::getBase(converter.genExprValue(
1838-
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
1865+
vectorLength.push_back(fir::getBase(converter.genExprValue(
1866+
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
1867+
vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
18391868
} else if (const auto *ifClause =
18401869
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
18411870
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -1986,18 +2015,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
19862015
else if ((defaultClause->v).v ==
19872016
llvm::acc::DefaultValue::ACC_Default_present)
19882017
hasDefaultPresent = true;
2018+
} else if (const auto *deviceTypeClause =
2019+
std::get_if<Fortran::parser::AccClause::DeviceType>(
2020+
&clause.u)) {
2021+
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
2022+
deviceTypeClause->v;
2023+
assert(deviceTypeExprList.v.size() == 1 &&
2024+
"expect only one device_type expr");
2025+
crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2026+
builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
19892027
}
19902028
}
19912029

19922030
// Prepare the operand segment size attribute and the operands value range.
19932031
llvm::SmallVector<mlir::Value, 8> operands;
19942032
llvm::SmallVector<int32_t, 8> operandSegments;
1995-
addOperand(operands, operandSegments, async);
2033+
addOperands(operands, operandSegments, async);
19962034
addOperands(operands, operandSegments, waitOperands);
19972035
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
19982036
addOperands(operands, operandSegments, numGangs);
1999-
addOperand(operands, operandSegments, numWorkers);
2000-
addOperand(operands, operandSegments, vectorLength);
2037+
addOperands(operands, operandSegments, numWorkers);
2038+
addOperands(operands, operandSegments, vectorLength);
20012039
}
20022040
addOperand(operands, operandSegments, ifCond);
20032041
addOperand(operands, operandSegments, selfCond);
@@ -2018,10 +2056,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20182056
builder, currentLocation, eval, operands, operandSegments,
20192057
outerCombined);
20202058

2021-
if (addAsyncAttr)
2022-
computeOp.setAsyncAttrAttr(builder.getUnitAttr());
2023-
if (addWaitAttr)
2024-
computeOp.setWaitAttrAttr(builder.getUnitAttr());
20252059
if (addSelfAttr)
20262060
computeOp.setSelfAttrAttr(builder.getUnitAttr());
20272061

@@ -2030,6 +2064,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20302064
if (hasDefaultPresent)
20312065
computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
20322066

2067+
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2068+
if (!numWorkersDeviceTypes.empty())
2069+
computeOp.setNumWorkersDeviceTypeAttr(
2070+
mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
2071+
if (!vectorLengthDeviceTypes.empty())
2072+
computeOp.setVectorLengthDeviceTypeAttr(
2073+
mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
2074+
if (!numGangsDeviceTypes.empty())
2075+
computeOp.setNumGangsDeviceTypeAttr(
2076+
mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
2077+
if (!numGangsSegments.empty())
2078+
computeOp.setNumGangsSegmentsAttr(
2079+
builder.getDenseI32ArrayAttr(numGangsSegments));
2080+
}
2081+
if (!asyncDeviceTypes.empty())
2082+
computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
2083+
if (!asyncOnlyDeviceTypes.empty())
2084+
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2085+
2086+
if (!waitOperandsDeviceTypes.empty())
2087+
computeOp.setWaitOperandsDeviceTypeAttr(
2088+
builder.getArrayAttr(waitOperandsDeviceTypes));
2089+
if (!waitOperandsSegments.empty())
2090+
computeOp.setWaitOperandsSegmentsAttr(
2091+
builder.getDenseI32ArrayAttr(waitOperandsSegments));
2092+
if (!waitOnlyDeviceTypes.empty())
2093+
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2094+
20332095
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
20342096
if (!privatizations.empty())
20352097
computeOp.setPrivatizationsAttr(
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
! This test checks lowering of OpenACC device_type clause on directive where its
2+
! position and the clauses that follow have special semantic
3+
4+
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
5+
6+
subroutine sub1()
7+
8+
!$acc parallel num_workers(16)
9+
!$acc end parallel
10+
11+
! CHECK: acc.parallel num_workers(%c16{{.*}} : i32) {
12+
13+
!$acc parallel num_workers(1) device_type(nvidia) num_workers(16)
14+
!$acc end parallel
15+
16+
! CHECK: acc.parallel num_workers(%c1{{.*}} : i32, %c16{{.*}} : i32 [#acc.device_type<nvidia>])
17+
18+
!$acc parallel device_type(*) num_workers(1) device_type(nvidia) num_workers(16)
19+
!$acc end parallel
20+
21+
! CHECK: acc.parallel num_workers(%c1{{.*}} : i32 [#acc.device_type<star>], %c16{{.*}} : i32 [#acc.device_type<nvidia>])
22+
23+
!$acc parallel vector_length(1)
24+
!$acc end parallel
25+
26+
! CHECK: acc.parallel vector_length(%c1{{.*}} : i32)
27+
28+
!$acc parallel device_type(multicore) vector_length(1)
29+
!$acc end parallel
30+
31+
! CHECK: acc.parallel vector_length(%c1{{.*}} : i32 [#acc.device_type<multicore>])
32+
33+
!$acc parallel num_gangs(2) device_type(nvidia) num_gangs(4)
34+
!$acc end parallel
35+
36+
! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c4{{.*}} : i32} [#acc.device_type<nvidia>])
37+
38+
!$acc parallel num_gangs(2) device_type(nvidia) num_gangs(1, 1, 1)
39+
!$acc end parallel
40+
41+
! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
42+
43+
44+
end subroutine

flang/test/Lower/OpenACC/acc-kernels-loop.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
6262
! CHECK: acc.yield
6363
! CHECK-NEXT: }{{$}}
6464
! CHECK: acc.terminator
65-
! CHECK-NEXT: } attributes {asyncAttr}
65+
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
6666

6767
!$acc kernels loop async(1)
6868
DO i = 1, n
@@ -103,15 +103,15 @@ subroutine acc_kernels_loop
103103
! CHECK: acc.yield
104104
! CHECK-NEXT: }{{$}}
105105
! CHECK: acc.terminator
106-
! CHECK-NEXT: } attributes {waitAttr}
106+
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
107107

108108
!$acc kernels loop wait(1)
109109
DO i = 1, n
110110
a(i) = b(i)
111111
END DO
112112

113113
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
114-
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
114+
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
115115
! CHECK: acc.loop {
116116
! CHECK: fir.do_loop
117117
! CHECK: acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
126126

127127
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
128128
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
129-
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
129+
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
130130
! CHECK: acc.loop {
131131
! CHECK: fir.do_loop
132132
! CHECK: acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
141141

142142
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
143143
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
144-
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
144+
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
145145
! CHECK: acc.loop {
146146
! CHECK: fir.do_loop
147147
! CHECK: acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
155155
END DO
156156

157157
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
158-
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
158+
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
159159
! CHECK: acc.loop {
160160
! CHECK: fir.do_loop
161161
! CHECK: acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
169169
END DO
170170

171171
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
172-
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
172+
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
173173
! CHECK: acc.loop {
174174
! CHECK: fir.do_loop
175175
! CHECK: acc.yield

flang/test/Lower/OpenACC/acc-kernels.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ subroutine acc_kernels
4040

4141
! CHECK: acc.kernels {
4242
! CHECK: acc.terminator
43-
! CHECK-NEXT: } attributes {asyncAttr}
43+
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
4444

4545
!$acc kernels async(1)
4646
!$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
6363

6464
! CHECK: acc.kernels {
6565
! CHECK: acc.terminator
66-
! CHECK-NEXT: } attributes {waitAttr}
66+
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
6767

6868
!$acc kernels wait(1)
6969
!$acc end kernels
7070

7171
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
72-
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
72+
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
7373
! CHECK: acc.terminator
7474
! CHECK-NEXT: }{{$}}
7575

@@ -78,7 +78,7 @@ subroutine acc_kernels
7878

7979
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
8080
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
81-
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
81+
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
8282
! CHECK: acc.terminator
8383
! CHECK-NEXT: }{{$}}
8484

@@ -87,23 +87,23 @@ subroutine acc_kernels
8787

8888
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
8989
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
90-
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
90+
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
9191
! CHECK: acc.terminator
9292
! CHECK-NEXT: }{{$}}
9393

9494
!$acc kernels num_gangs(1)
9595
!$acc end kernels
9696

9797
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
98-
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
98+
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
9999
! CHECK: acc.terminator
100100
! CHECK-NEXT: }{{$}}
101101

102102
!$acc kernels num_gangs(numGangs)
103103
!$acc end kernels
104104

105105
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
106-
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
106+
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
107107
! CHECK: acc.terminator
108108
! CHECK-NEXT: }{{$}}
109109

0 commit comments

Comments
 (0)