@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
1480
1480
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
1481
1481
return mlir::acc::DeviceType::Multicore;
1482
1482
}
1483
- return mlir::acc::DeviceType::None ;
1483
+ return mlir::acc::DeviceType::Default ;
1484
1484
}
1485
1485
1486
1486
static void gatherDeviceTypeAttrs (
@@ -1781,89 +1781,60 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
1781
1781
bool outerCombined = false ) {
1782
1782
1783
1783
// Parallel operation operands
1784
+ mlir::Value async;
1785
+ mlir::Value numWorkers;
1786
+ mlir::Value vectorLength;
1784
1787
mlir::Value ifCond;
1785
1788
mlir::Value selfCond;
1786
1789
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
1787
1790
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
1788
- dataClauseOperands, numGangs, numWorkers, vectorLength, async;
1789
- llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
1790
- vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
1791
- waitOperandsDeviceTypes, waitOnlyDeviceTypes;
1792
- llvm::SmallVector<int32_t > numGangsSegments, waitOperandsSegments;
1791
+ dataClauseOperands, numGangs;
1793
1792
1794
1793
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
1795
1794
firstprivateOperands;
1796
1795
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
1797
1796
reductionRecipes;
1798
1797
1799
- // Self clause has optional values but can be present with
1798
+ // Async, wait and self clause have optional values but can be present with
1800
1799
// no value as well. When there is no value, the op has an attribute to
1801
1800
// represent the clause.
1801
+ bool addAsyncAttr = false ;
1802
+ bool addWaitAttr = false ;
1802
1803
bool addSelfAttr = false ;
1803
1804
1804
1805
bool hasDefaultNone = false ;
1805
1806
bool hasDefaultPresent = false ;
1806
1807
1807
1808
fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
1808
1809
1809
- // device_type attribute is set to `none` until a device_type clause is
1810
- // encountered.
1811
- auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
1812
- builder.getContext (), mlir::acc::DeviceType::None);
1813
-
1814
1810
// Lower clauses values mapped to operands.
1815
1811
// Keep track of each group of operands separatly as clauses can appear
1816
1812
// more than once.
1817
1813
for (const Fortran::parser::AccClause &clause : accClauseList.v ) {
1818
1814
mlir::Location clauseLocation = converter.genLocation (clause.source );
1819
1815
if (const auto *asyncClause =
1820
1816
std::get_if<Fortran::parser::AccClause::Async>(&clause.u )) {
1821
- const auto &asyncClauseValue = asyncClause->v ;
1822
- if (asyncClauseValue) { // async has a value.
1823
- async.push_back (fir::getBase (converter.genExprValue (
1824
- *Fortran::semantics::GetExpr (*asyncClauseValue), stmtCtx)));
1825
- asyncDeviceTypes.push_back (crtDeviceTypeAttr);
1826
- } else {
1827
- asyncOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1828
- }
1817
+ genAsyncClause (converter, asyncClause, async, addAsyncAttr, stmtCtx);
1829
1818
} else if (const auto *waitClause =
1830
1819
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
1831
- const auto &waitClauseValue = waitClause->v ;
1832
- if (waitClauseValue) { // wait has a value.
1833
- const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1834
- const auto &waitList =
1835
- std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1836
- auto crtWaitOperands = waitOperands.size ();
1837
- for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1838
- waitOperands.push_back (fir::getBase (converter.genExprValue (
1839
- *Fortran::semantics::GetExpr (value), stmtCtx)));
1840
- }
1841
- waitOperandsDeviceTypes.push_back (crtDeviceTypeAttr);
1842
- waitOperandsSegments.push_back (waitOperands.size () - crtWaitOperands);
1843
- } else {
1844
- waitOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1845
- }
1820
+ genWaitClause (converter, waitClause, waitOperands, waitDevnum,
1821
+ addWaitAttr, stmtCtx);
1846
1822
} else if (const auto *numGangsClause =
1847
1823
std::get_if<Fortran::parser::AccClause::NumGangs>(
1848
1824
&clause.u )) {
1849
- auto crtNumGangs = numGangs.size ();
1850
1825
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v )
1851
1826
numGangs.push_back (fir::getBase (converter.genExprValue (
1852
1827
*Fortran::semantics::GetExpr (expr), stmtCtx)));
1853
- numGangsDeviceTypes.push_back (crtDeviceTypeAttr);
1854
- numGangsSegments.push_back (numGangs.size () - crtNumGangs);
1855
1828
} else if (const auto *numWorkersClause =
1856
1829
std::get_if<Fortran::parser::AccClause::NumWorkers>(
1857
1830
&clause.u )) {
1858
- numWorkers.push_back (fir::getBase (converter.genExprValue (
1859
- *Fortran::semantics::GetExpr (numWorkersClause->v ), stmtCtx)));
1860
- numWorkersDeviceTypes.push_back (crtDeviceTypeAttr);
1831
+ numWorkers = fir::getBase (converter.genExprValue (
1832
+ *Fortran::semantics::GetExpr (numWorkersClause->v ), stmtCtx));
1861
1833
} else if (const auto *vectorLengthClause =
1862
1834
std::get_if<Fortran::parser::AccClause::VectorLength>(
1863
1835
&clause.u )) {
1864
- vectorLength.push_back (fir::getBase (converter.genExprValue (
1865
- *Fortran::semantics::GetExpr (vectorLengthClause->v ), stmtCtx)));
1866
- vectorLengthDeviceTypes.push_back (crtDeviceTypeAttr);
1836
+ vectorLength = fir::getBase (converter.genExprValue (
1837
+ *Fortran::semantics::GetExpr (vectorLengthClause->v ), stmtCtx));
1867
1838
} else if (const auto *ifClause =
1868
1839
std::get_if<Fortran::parser::AccClause::If>(&clause.u )) {
1869
1840
genIfClause (converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -2014,27 +1985,18 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2014
1985
else if ((defaultClause->v ).v ==
2015
1986
llvm::acc::DefaultValue::ACC_Default_present)
2016
1987
hasDefaultPresent = true ;
2017
- } else if (const auto *deviceTypeClause =
2018
- std::get_if<Fortran::parser::AccClause::DeviceType>(
2019
- &clause.u )) {
2020
- const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
2021
- deviceTypeClause->v ;
2022
- assert (deviceTypeExprList.v .size () == 1 &&
2023
- " expect only one device_type expr" );
2024
- crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
2025
- builder.getContext (), getDeviceType (deviceTypeExprList.v .front ().v ));
2026
1988
}
2027
1989
}
2028
1990
2029
1991
// Prepare the operand segment size attribute and the operands value range.
2030
1992
llvm::SmallVector<mlir::Value, 8 > operands;
2031
1993
llvm::SmallVector<int32_t , 8 > operandSegments;
2032
- addOperands (operands, operandSegments, async);
1994
+ addOperand (operands, operandSegments, async);
2033
1995
addOperands (operands, operandSegments, waitOperands);
2034
1996
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2035
1997
addOperands (operands, operandSegments, numGangs);
2036
- addOperands (operands, operandSegments, numWorkers);
2037
- addOperands (operands, operandSegments, vectorLength);
1998
+ addOperand (operands, operandSegments, numWorkers);
1999
+ addOperand (operands, operandSegments, vectorLength);
2038
2000
}
2039
2001
addOperand (operands, operandSegments, ifCond);
2040
2002
addOperand (operands, operandSegments, selfCond);
@@ -2055,6 +2017,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2055
2017
builder, currentLocation, eval, operands, operandSegments,
2056
2018
outerCombined);
2057
2019
2020
+ if (addAsyncAttr)
2021
+ computeOp.setAsyncAttrAttr (builder.getUnitAttr ());
2022
+ if (addWaitAttr)
2023
+ computeOp.setWaitAttrAttr (builder.getUnitAttr ());
2058
2024
if (addSelfAttr)
2059
2025
computeOp.setSelfAttrAttr (builder.getUnitAttr ());
2060
2026
@@ -2063,34 +2029,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2063
2029
if (hasDefaultPresent)
2064
2030
computeOp.setDefaultAttr (mlir::acc::ClauseDefaultValue::Present);
2065
2031
2066
- if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2067
- if (!numWorkersDeviceTypes.empty ())
2068
- computeOp.setNumWorkersDeviceTypeAttr (
2069
- mlir::ArrayAttr::get (builder.getContext (), numWorkersDeviceTypes));
2070
- if (!vectorLengthDeviceTypes.empty ())
2071
- computeOp.setVectorLengthDeviceTypeAttr (
2072
- mlir::ArrayAttr::get (builder.getContext (), vectorLengthDeviceTypes));
2073
- if (!numGangsDeviceTypes.empty ())
2074
- computeOp.setNumGangsDeviceTypeAttr (
2075
- mlir::ArrayAttr::get (builder.getContext (), numGangsDeviceTypes));
2076
- if (!numGangsSegments.empty ())
2077
- computeOp.setNumGangsSegmentsAttr (
2078
- builder.getDenseI32ArrayAttr (numGangsSegments));
2079
- }
2080
- if (!asyncDeviceTypes.empty ())
2081
- computeOp.setAsyncDeviceTypeAttr (builder.getArrayAttr (asyncDeviceTypes));
2082
- if (!asyncOnlyDeviceTypes.empty ())
2083
- computeOp.setAsyncOnlyAttr (builder.getArrayAttr (asyncOnlyDeviceTypes));
2084
-
2085
- if (!waitOperandsDeviceTypes.empty ())
2086
- computeOp.setWaitOperandsDeviceTypeAttr (
2087
- builder.getArrayAttr (waitOperandsDeviceTypes));
2088
- if (!waitOperandsSegments.empty ())
2089
- computeOp.setWaitOperandsSegmentsAttr (
2090
- builder.getDenseI32ArrayAttr (waitOperandsSegments));
2091
- if (!waitOnlyDeviceTypes.empty ())
2092
- computeOp.setWaitOnlyAttr (builder.getArrayAttr (waitOnlyDeviceTypes));
2093
-
2094
2032
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2095
2033
if (!privatizations.empty ())
2096
2034
computeOp.setPrivatizationsAttr (
0 commit comments