Skip to content

Commit 5537483

Browse files
committed
Revert "[mlir][openacc] Add device_type support for compute operations (#75864)"
This reverts commit 8b885eb.
1 parent e98082d commit 5537483

File tree

15 files changed

+177
-1216
lines changed

15 files changed

+177
-1216
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 22 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
14801480
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
14811481
return mlir::acc::DeviceType::Multicore;
14821482
}
1483-
return mlir::acc::DeviceType::None;
1483+
return mlir::acc::DeviceType::Default;
14841484
}
14851485

14861486
static void gatherDeviceTypeAttrs(
@@ -1781,90 +1781,61 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
17811781
bool outerCombined = false) {
17821782

17831783
// Parallel operation operands
1784+
mlir::Value async;
1785+
mlir::Value numWorkers;
1786+
mlir::Value vectorLength;
17841787
mlir::Value ifCond;
17851788
mlir::Value selfCond;
17861789
mlir::Value waitDevnum;
17871790
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
17881791
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
1789-
dataClauseOperands, numGangs, numWorkers, vectorLength, async;
1790-
llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
1791-
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
1792-
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
1793-
llvm::SmallVector<int32_t> numGangsSegments, waitOperandsSegments;
1792+
dataClauseOperands, numGangs;
17941793

17951794
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
17961795
firstprivateOperands;
17971796
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
17981797
reductionRecipes;
17991798

1800-
// Self clause has optional values but can be present with
1799+
// Async, wait and self clauses have optional values but can be present with
18011800
// no value as well. When there is no value, the op has an attribute to
18021801
// represent the clause.
1802+
bool addAsyncAttr = false;
1803+
bool addWaitAttr = false;
18031804
bool addSelfAttr = false;
18041805

18051806
bool hasDefaultNone = false;
18061807
bool hasDefaultPresent = false;
18071808

18081809
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
18091810

1810-
// device_type attribute is set to `none` until a device_type clause is
1811-
// encountered.
1812-
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
1813-
builder.getContext(), mlir::acc::DeviceType::None);
1814-
18151811
// Lower clauses values mapped to operands.
18161812
// Keep track of each group of operands separately as clauses can appear
18171813
// more than once.
18181814
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
18191815
mlir::Location clauseLocation = converter.genLocation(clause.source);
18201816
if (const auto *asyncClause =
18211817
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
1822-
const auto &asyncClauseValue = asyncClause->v;
1823-
if (asyncClauseValue) { // async has a value.
1824-
async.push_back(fir::getBase(converter.genExprValue(
1825-
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
1826-
asyncDeviceTypes.push_back(crtDeviceTypeAttr);
1827-
} else {
1828-
asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1829-
}
1818+
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
18301819
} else if (const auto *waitClause =
18311820
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
1832-
const auto &waitClauseValue = waitClause->v;
1833-
if (waitClauseValue) { // wait has a value.
1834-
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1835-
const auto &waitList =
1836-
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1837-
auto crtWaitOperands = waitOperands.size();
1838-
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1839-
waitOperands.push_back(fir::getBase(converter.genExprValue(
1840-
*Fortran::semantics::GetExpr(value), stmtCtx)));
1841-
}
1842-
waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1843-
waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
1844-
} else {
1845-
waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1846-
}
1821+
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
1822+
addWaitAttr, stmtCtx);
18471823
} else if (const auto *numGangsClause =
18481824
std::get_if<Fortran::parser::AccClause::NumGangs>(
18491825
&clause.u)) {
1850-
auto crtNumGangs = numGangs.size();
18511826
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
18521827
numGangs.push_back(fir::getBase(converter.genExprValue(
18531828
*Fortran::semantics::GetExpr(expr), stmtCtx)));
1854-
numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
1855-
numGangsSegments.push_back(numGangs.size() - crtNumGangs);
18561829
} else if (const auto *numWorkersClause =
18571830
std::get_if<Fortran::parser::AccClause::NumWorkers>(
18581831
&clause.u)) {
1859-
numWorkers.push_back(fir::getBase(converter.genExprValue(
1860-
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
1861-
numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
1832+
numWorkers = fir::getBase(converter.genExprValue(
1833+
*Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
18621834
} else if (const auto *vectorLengthClause =
18631835
std::get_if<Fortran::parser::AccClause::VectorLength>(
18641836
&clause.u)) {
1865-
vectorLength.push_back(fir::getBase(converter.genExprValue(
1866-
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
1867-
vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
1837+
vectorLength = fir::getBase(converter.genExprValue(
1838+
*Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
18681839
} else if (const auto *ifClause =
18691840
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
18701841
genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -2015,27 +1986,18 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20151986
else if ((defaultClause->v).v ==
20161987
llvm::acc::DefaultValue::ACC_Default_present)
20171988
hasDefaultPresent = true;
2018-
} else if (const auto *deviceTypeClause =
2019-
std::get_if<Fortran::parser::AccClause::DeviceType>(
2020-
&clause.u)) {
2021-
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
2022-
deviceTypeClause->v;
2023-
assert(deviceTypeExprList.v.size() == 1 &&
2024-
"expect only one device_type expr");
2025-
crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2026-
builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
20271989
}
20281990
}
20291991

20301992
// Prepare the operand segment size attribute and the operands value range.
20311993
llvm::SmallVector<mlir::Value, 8> operands;
20321994
llvm::SmallVector<int32_t, 8> operandSegments;
2033-
addOperands(operands, operandSegments, async);
1995+
addOperand(operands, operandSegments, async);
20341996
addOperands(operands, operandSegments, waitOperands);
20351997
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
20361998
addOperands(operands, operandSegments, numGangs);
2037-
addOperands(operands, operandSegments, numWorkers);
2038-
addOperands(operands, operandSegments, vectorLength);
1999+
addOperand(operands, operandSegments, numWorkers);
2000+
addOperand(operands, operandSegments, vectorLength);
20392001
}
20402002
addOperand(operands, operandSegments, ifCond);
20412003
addOperand(operands, operandSegments, selfCond);
@@ -2056,6 +2018,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20562018
builder, currentLocation, eval, operands, operandSegments,
20572019
outerCombined);
20582020

2021+
if (addAsyncAttr)
2022+
computeOp.setAsyncAttrAttr(builder.getUnitAttr());
2023+
if (addWaitAttr)
2024+
computeOp.setWaitAttrAttr(builder.getUnitAttr());
20592025
if (addSelfAttr)
20602026
computeOp.setSelfAttrAttr(builder.getUnitAttr());
20612027

@@ -2064,34 +2030,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
20642030
if (hasDefaultPresent)
20652031
computeOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
20662032

2067-
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2068-
if (!numWorkersDeviceTypes.empty())
2069-
computeOp.setNumWorkersDeviceTypeAttr(
2070-
mlir::ArrayAttr::get(builder.getContext(), numWorkersDeviceTypes));
2071-
if (!vectorLengthDeviceTypes.empty())
2072-
computeOp.setVectorLengthDeviceTypeAttr(
2073-
mlir::ArrayAttr::get(builder.getContext(), vectorLengthDeviceTypes));
2074-
if (!numGangsDeviceTypes.empty())
2075-
computeOp.setNumGangsDeviceTypeAttr(
2076-
mlir::ArrayAttr::get(builder.getContext(), numGangsDeviceTypes));
2077-
if (!numGangsSegments.empty())
2078-
computeOp.setNumGangsSegmentsAttr(
2079-
builder.getDenseI32ArrayAttr(numGangsSegments));
2080-
}
2081-
if (!asyncDeviceTypes.empty())
2082-
computeOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
2083-
if (!asyncOnlyDeviceTypes.empty())
2084-
computeOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2085-
2086-
if (!waitOperandsDeviceTypes.empty())
2087-
computeOp.setWaitOperandsDeviceTypeAttr(
2088-
builder.getArrayAttr(waitOperandsDeviceTypes));
2089-
if (!waitOperandsSegments.empty())
2090-
computeOp.setWaitOperandsSegmentsAttr(
2091-
builder.getDenseI32ArrayAttr(waitOperandsSegments));
2092-
if (!waitOnlyDeviceTypes.empty())
2093-
computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
2094-
20952033
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
20962034
if (!privatizations.empty())
20972035
computeOp.setPrivatizationsAttr(

flang/test/Lower/OpenACC/acc-device-type.f90

Lines changed: 0 additions & 44 deletions
This file was deleted.

flang/test/Lower/OpenACC/acc-kernels-loop.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ subroutine acc_kernels_loop
6262
! CHECK: acc.yield
6363
! CHECK-NEXT: }{{$}}
6464
! CHECK: acc.terminator
65-
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
65+
! CHECK-NEXT: } attributes {asyncAttr}
6666

6767
!$acc kernels loop async(1)
6868
DO i = 1, n
@@ -103,15 +103,15 @@ subroutine acc_kernels_loop
103103
! CHECK: acc.yield
104104
! CHECK-NEXT: }{{$}}
105105
! CHECK: acc.terminator
106-
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
106+
! CHECK-NEXT: } attributes {waitAttr}
107107

108108
!$acc kernels loop wait(1)
109109
DO i = 1, n
110110
a(i) = b(i)
111111
END DO
112112

113113
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
114-
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
114+
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
115115
! CHECK: acc.loop {
116116
! CHECK: fir.do_loop
117117
! CHECK: acc.yield
@@ -126,7 +126,7 @@ subroutine acc_kernels_loop
126126

127127
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
128128
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
129-
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
129+
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
130130
! CHECK: acc.loop {
131131
! CHECK: fir.do_loop
132132
! CHECK: acc.yield
@@ -141,7 +141,7 @@ subroutine acc_kernels_loop
141141

142142
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
143143
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
144-
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
144+
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
145145
! CHECK: acc.loop {
146146
! CHECK: fir.do_loop
147147
! CHECK: acc.yield
@@ -155,7 +155,7 @@ subroutine acc_kernels_loop
155155
END DO
156156

157157
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
158-
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
158+
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
159159
! CHECK: acc.loop {
160160
! CHECK: fir.do_loop
161161
! CHECK: acc.yield
@@ -169,7 +169,7 @@ subroutine acc_kernels_loop
169169
END DO
170170

171171
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
172-
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
172+
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
173173
! CHECK: acc.loop {
174174
! CHECK: fir.do_loop
175175
! CHECK: acc.yield

flang/test/Lower/OpenACC/acc-kernels.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ subroutine acc_kernels
4040

4141
! CHECK: acc.kernels {
4242
! CHECK: acc.terminator
43-
! CHECK-NEXT: } attributes {asyncOnly = [#acc.device_type<none>]}
43+
! CHECK-NEXT: } attributes {asyncAttr}
4444

4545
!$acc kernels async(1)
4646
!$acc end kernels
@@ -63,13 +63,13 @@ subroutine acc_kernels
6363

6464
! CHECK: acc.kernels {
6565
! CHECK: acc.terminator
66-
! CHECK-NEXT: } attributes {waitOnly = [#acc.device_type<none>]}
66+
! CHECK-NEXT: } attributes {waitAttr}
6767

6868
!$acc kernels wait(1)
6969
!$acc end kernels
7070

7171
! CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
72-
! CHECK: acc.kernels wait({[[WAIT1]] : i32}) {
72+
! CHECK: acc.kernels wait([[WAIT1]] : i32) {
7373
! CHECK: acc.terminator
7474
! CHECK-NEXT: }{{$}}
7575

@@ -78,7 +78,7 @@ subroutine acc_kernels
7878

7979
! CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
8080
! CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
81-
! CHECK: acc.kernels wait({[[WAIT2]] : i32, [[WAIT3]] : i32}) {
81+
! CHECK: acc.kernels wait([[WAIT2]], [[WAIT3]] : i32, i32) {
8282
! CHECK: acc.terminator
8383
! CHECK-NEXT: }{{$}}
8484

@@ -87,23 +87,23 @@ subroutine acc_kernels
8787

8888
! CHECK: [[WAIT4:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
8989
! CHECK: [[WAIT5:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
90-
! CHECK: acc.kernels wait({[[WAIT4]] : i32, [[WAIT5]] : i32}) {
90+
! CHECK: acc.kernels wait([[WAIT4]], [[WAIT5]] : i32, i32) {
9191
! CHECK: acc.terminator
9292
! CHECK-NEXT: }{{$}}
9393

9494
!$acc kernels num_gangs(1)
9595
!$acc end kernels
9696

9797
! CHECK: [[NUMGANGS1:%.*]] = arith.constant 1 : i32
98-
! CHECK: acc.kernels num_gangs({[[NUMGANGS1]] : i32}) {
98+
! CHECK: acc.kernels num_gangs([[NUMGANGS1]] : i32) {
9999
! CHECK: acc.terminator
100100
! CHECK-NEXT: }{{$}}
101101

102102
!$acc kernels num_gangs(numGangs)
103103
!$acc end kernels
104104

105105
! CHECK: [[NUMGANGS2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
106-
! CHECK: acc.kernels num_gangs({[[NUMGANGS2]] : i32}) {
106+
! CHECK: acc.kernels num_gangs([[NUMGANGS2]] : i32) {
107107
! CHECK: acc.terminator
108108
! CHECK-NEXT: }{{$}}
109109

0 commit comments

Comments
 (0)