Skip to content

Commit 10df608

Browse files
authored
Revert "[mlir][openacc] Add device_type support for compute operations (#75864)"
This reverts commit 8b885eb.
1 parent 7c9c807 commit 10df608

File tree

15 files changed

+177
-1216
lines changed

15 files changed

+177
-1216
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 22 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
14801480
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
14811481
return mlir::acc::DeviceType::Multicore;
14821482
}
1483-
return mlir::acc::DeviceType::None;
1483+
return mlir::acc::DeviceType::Default;
14841484
}
14851485

14861486
static void gatherDeviceTypeAttrs(
@@ -1781,89 +1781,60 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
17811781
bool outerCombined = false) {
17821782

17831783
// Parallel operation operands
1784+
mlir::Value async;
1785+
mlir::Value numWorkers;
1786+
mlir::Value vectorLength;
17841787
mlir::Value ifCond;
17851788
mlir::Value selfCond;
17861789
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
17871790
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
1788-
dataClauseOperands, numGangs, numWorkers, vectorLength, async;
1789-
llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
1790-
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
1791-
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
1792-
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
1791+
dataClauseOperands, numGangs;
17931792

17941793
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
17951794
firstprivateOperands;
17961795
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
17971796
reductionRecipes;
17981797

1799-
// Self clause has optional values but can be present with
1798+
// Async, wait and self clause have optional values but can be present with
18001799
// no value as well. When there is no value, the op has an attribute to
18011800
// represent the clause.
1801+
bool addAsyncAttr = false;
1802+
bool addWaitAttr = false;
18021803
bool addSelfAttr = false;
18031804

18041805
bool hasDefaultNone = false;
18051806
bool hasDefaultPresent = false;
18061807

18071808
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
18081809

1809-
// device_type attribute is set to `none` until a device_type clause is
1810-
// encountered.
1811-
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
1812-
builder.getContext(), mlir::acc::DeviceType::None);
1813-
18141810
// Lower clauses values mapped to operands.
18151811
// Keep track of each group of operands separatly as clauses can appear
18161812
// more than once.
18171813
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
18181814
mlir::Location clauseLocation = converter.genLocation(clause.source);
18191815
if (const auto *asyncClause =
18201816
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
1821-
const auto &asyncClauseValue = asyncClause->v;
1822-
if (asyncClauseValue) { // async has a value.
1823-
async.push_back(fir::getBase(converter.genExprValue(
1824-
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
1825-
asyncDeviceTypes.push_back(crtDeviceTypeAttr);
1826-
} else {
1827-
asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1828-
}
1817+
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
18291818
} else if (const auto *waitClause =
18301819
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
1831-
const auto &waitClauseValue = waitClause->v;
1832-
if (waitClauseValue) { // wait has a value.
1833-
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1834-
const auto &waitList =
1835-
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1836-
auto crtWaitOperands = waitOperands.size();
1837-
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1838-
waitOperands.push_back(fir::getBase(converter.genExprValue(
1839-
*Fortran::semantics::GetExpr(value), stmtCtx)));
1840-
}
1841-
waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1842-
waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
1843-
} else {
1844-
waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1845-
}
1820+
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
1821+
addWaitAttr, stmtCtx);
18461822
} else if (const auto *numGangsClause =
18471823
std::get_if<Fortran::parser::AccClause::NumGangs>(
18481824
&clause.u)) {
1849-
auto crtNumGangs = numGangs.size();
18501825
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
18511826
numGangs.push_back(fir::getBase(converter.genExprValue(
18521827
*Fortran::semantics::GetExpr(expr), stmtCtx)));
1853-
numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
1854-
numGangsSegments.push_back(numGangs.size() - crtNumGangs);
18551828
} else if (const auto *numWorkersClause =
18561829
std::get_if<Fortran::parser::AccClause::NumWorkers>(
18571830
&clause.u)) {
1858-
numWorkers.push_back(fir::getBase(converter.genExprValue(
1859-
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
1860-
numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
1831+
numWorkers = fir::getBase(converter.genExprValue(
1832+
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
18611833
} else if (const auto *vectorLengthClause =
18621834
std::get_if<Fortran::parser::AccClause::VectorLength>(
18631835
&clause.u)) {
1864-
vectorLength.push_back(fir::getBase(converter.genExprValue(
1865-
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
1866-
vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
1836+
vectorLength = fir::getBase(converter.genExprValue(
1837+
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
18671838
} else if (const auto *ifClause =
18681839
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
18691840
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -2014,27 +1985,18 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20141985
else if ((defaultClause->v).v ==
20151986
llvm::acc::DefaultValue::ACC_Default_present)
20161987
hasDefaultPresent = true;
2017-
} else if (const auto *deviceTypeClause =
2018-
std::get_if<Fortran::parser::AccClause::DeviceType>(
2019-
&clause.u)) {
2020-
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
2021-
deviceTypeClause->v;
2022-
assert(deviceTypeExprList.v.size() == 1 &&
2023-
"expect only one device_type expr");
2024-
crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2025-
builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
20261988
}
20271989
}
20281990

20291991
// Prepare the operand segment size attribute and the operands value range.
20301992
llvm::SmallVector<mlir::Value, 8> operands;
20311993
llvm::SmallVector<int32_t, 8> operandSegments;
2032-
addOperands(operands, operandSegments, async);
1994+
addOperand(operands, operandSegments, async);
20331995
addOperands(operands, operandSegments, waitOperands);
20341996
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
20351997
addOperands(operands, operandSegments, numGangs);
2036-
addOperands(operands, operandSegments, numWorkers);
2037-
addOperands(operands, operandSegments, vectorLength);
1998+
addOperand(operands, operandSegments, numWorkers);
1999+
addOperand(operands, operandSegments, vectorLength);
20382000
}
20392001
addOperand(operands, operandSegments, ifCond);
20402002
addOperand(operands, operandSegments, selfCond);
@@ -2055,6 +2017,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20552017
builder, currentLocation, eval, operands, operandSegments,
20562018
outerCombined);
20572019

2020+
if (addAsyncAttr)
2021+
computeOp.setAsyncAttrAttr(builder.getUnitAttr());
2022+
if (addWaitAttr)
2023+
computeOp.setWaitAttrAttr(builder.getUnitAttr());
20582024
if (addSelfAttr)
20592025
computeOp.setSelfAttrAttr(builder.getUnitAttr());
20602026

@@ -2063,34 +2029,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20632029
if (hasDefaultPresent)
20642030
computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
20652031

2066-
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2067-
if (!numWorkersDeviceTypes.empty())
2068-
computeOp.setNumWorkersDeviceTypeAttr(
2069-
mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
2070-
if (!vectorLengthDeviceTypes.empty())
2071-
computeOp.setVectorLengthDeviceTypeAttr(
2072-
mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
2073-
if (!numGangsDeviceTypes.empty())
2074-
computeOp.setNumGangsDeviceTypeAttr(
2075-
mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
2076-
if (!numGangsSegments.empty())
2077-
computeOp.setNumGangsSegmentsAttr(
2078-
builder.getDenseI32ArrayAttr(numGangsSegments));
2079-
}
2080-
if (!asyncDeviceTypes.empty())
2081-
computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
2082-
if (!asyncOnlyDeviceTypes.empty())
2083-
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2084-
2085-
if (!waitOperandsDeviceTypes.empty())
2086-
computeOp.setWaitOperandsDeviceTypeAttr(
2087-
builder.getArrayAttr(waitOperandsDeviceTypes));
2088-
if (!waitOperandsSegments.empty())
2089-
computeOp.setWaitOperandsSegmentsAttr(
2090-
builder.getDenseI32ArrayAttr(waitOperandsSegments));
2091-
if (!waitOnlyDeviceTypes.empty())
2092-
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2093-
20942032
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
20952033
if (!privatizations.empty())
20962034
computeOp.setPrivatizationsAttr(

flang/test/Lower/OpenACC/acc-device-type.f90

Lines changed: 0 additions & 44 deletions
This file was deleted.

flang/test/Lower/OpenACC/acc-kernels-loop.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
6262
! CHECK: acc.yield
6363
! CHECK-NEXT: }{{$}}
6464
! CHECK: acc.terminator
65-
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
65+
! CHECK-NEXT: } attributes {asyncAttr}
6666

6767
!$acc kernels loop async(1)
6868
DO i = 1, n
@@ -103,15 +103,15 @@ subroutine acc_kernels_loop
103103
! CHECK: acc.yield
104104
! CHECK-NEXT: }{{$}}
105105
! CHECK: acc.terminator
106-
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
106+
! CHECK-NEXT: } attributes {waitAttr}
107107

108108
!$acc kernels loop wait(1)
109109
DO i = 1, n
110110
a(i) = b(i)
111111
END DO
112112

113113
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
114-
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
114+
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
115115
! CHECK: acc.loop {
116116
! CHECK: fir.do_loop
117117
! CHECK: acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
126126

127127
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
128128
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
129-
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
129+
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
130130
! CHECK: acc.loop {
131131
! CHECK: fir.do_loop
132132
! CHECK: acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
141141

142142
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
143143
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
144-
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
144+
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
145145
! CHECK: acc.loop {
146146
! CHECK: fir.do_loop
147147
! CHECK: acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
155155
END DO
156156

157157
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
158-
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
158+
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
159159
! CHECK: acc.loop {
160160
! CHECK: fir.do_loop
161161
! CHECK: acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
169169
END DO
170170

171171
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
172-
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
172+
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
173173
! CHECK: acc.loop {
174174
! CHECK: fir.do_loop
175175
! CHECK: acc.yield

flang/test/Lower/OpenACC/acc-kernels.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ subroutine acc_kernels
4040

4141
! CHECK: acc.kernels {
4242
! CHECK: acc.terminator
43-
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
43+
! CHECK-NEXT: } attributes {asyncAttr}
4444

4545
!$acc kernels async(1)
4646
!$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
6363

6464
! CHECK: acc.kernels {
6565
! CHECK: acc.terminator
66-
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
66+
! CHECK-NEXT: } attributes {waitAttr}
6767

6868
!$acc kernels wait(1)
6969
!$acc end kernels
7070

7171
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
72-
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
72+
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
7373
! CHECK: acc.terminator
7474
! CHECK-NEXT: }{{$}}
7575

@@ -78,7 +78,7 @@ subroutine acc_kernels
7878

7979
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
8080
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
81-
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
81+
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
8282
! CHECK: acc.terminator
8383
! CHECK-NEXT: }{{$}}
8484

@@ -87,23 +87,23 @@ subroutine acc_kernels
8787

8888
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
8989
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
90-
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
90+
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
9191
! CHECK: acc.terminator
9292
! CHECK-NEXT: }{{$}}
9393

9494
!$acc kernels num_gangs(1)
9595
!$acc end kernels
9696

9797
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
98-
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
98+
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
9999
! CHECK: acc.terminator
100100
! CHECK-NEXT: }{{$}}
101101

102102
!$acc kernels num_gangs(numGangs)
103103
!$acc end kernels
104104

105105
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
106-
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
106+
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
107107
! CHECK: acc.terminator
108108
! CHECK-NEXT: }{{$}}
109109

0 commit comments

Comments
 (0)