@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
1480
1480
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
1481
1481
return mlir::acc::DeviceType::Multicore;
1482
1482
}
1483
- return mlir::acc::DeviceType::Default ;
1483
+ return mlir::acc::DeviceType::None ;
1484
1484
}
1485
1485
1486
1486
static void gatherDeviceTypeAttrs (
@@ -1781,61 +1781,90 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
1781
1781
bool outerCombined = false ) {
1782
1782
1783
1783
// Parallel operation operands
1784
- mlir::Value async;
1785
- mlir::Value numWorkers;
1786
- mlir::Value vectorLength;
1787
1784
mlir::Value ifCond;
1788
1785
mlir::Value selfCond;
1789
1786
mlir::Value waitDevnum;
1790
1787
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
1791
1788
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
1792
- dataClauseOperands, numGangs;
1789
+ dataClauseOperands, numGangs, numWorkers, vectorLength, async;
1790
+ llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
1791
+ vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
1792
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes;
1793
+ llvm::SmallVector<int32_t > numGangsSegments, waitOperandsSegments;
1793
1794
1794
1795
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
1795
1796
firstprivateOperands;
1796
1797
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
1797
1798
reductionRecipes;
1798
1799
1799
- // Async, wait and self clause have optional values but can be present with
1800
+ // Self clause has optional values but can be present with
1800
1801
// no value as well. When there is no value, the op has an attribute to
1801
1802
// represent the clause.
1802
- bool addAsyncAttr = false ;
1803
- bool addWaitAttr = false ;
1804
1803
bool addSelfAttr = false ;
1805
1804
1806
1805
bool hasDefaultNone = false ;
1807
1806
bool hasDefaultPresent = false ;
1808
1807
1809
1808
fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
1810
1809
1810
+ // device_type attribute is set to `none` until a device_type clause is
1811
+ // encountered.
1812
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
1813
+ builder.getContext (), mlir::acc::DeviceType::None);
1814
+
1811
1815
// Lower clauses values mapped to operands.
1812
1816
// Keep track of each group of operands separatly as clauses can appear
1813
1817
// more than once.
1814
1818
for (const Fortran::parser::AccClause &clause : accClauseList.v ) {
1815
1819
mlir::Location clauseLocation = converter.genLocation (clause.source );
1816
1820
if (const auto *asyncClause =
1817
1821
std::get_if<Fortran::parser::AccClause::Async>(&clause.u )) {
1818
- genAsyncClause (converter, asyncClause, async, addAsyncAttr, stmtCtx);
1822
+ const auto &asyncClauseValue = asyncClause->v ;
1823
+ if (asyncClauseValue) { // async has a value.
1824
+ async.push_back (fir::getBase (converter.genExprValue (
1825
+ *Fortran::semantics::GetExpr (*asyncClauseValue), stmtCtx)));
1826
+ asyncDeviceTypes.push_back (crtDeviceTypeAttr);
1827
+ } else {
1828
+ asyncOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1829
+ }
1819
1830
} else if (const auto *waitClause =
1820
1831
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
1821
- genWaitClause (converter, waitClause, waitOperands, waitDevnum,
1822
- addWaitAttr, stmtCtx);
1832
+ const auto &waitClauseValue = waitClause->v ;
1833
+ if (waitClauseValue) { // wait has a value.
1834
+ const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1835
+ const auto &waitList =
1836
+ std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1837
+ auto crtWaitOperands = waitOperands.size ();
1838
+ for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1839
+ waitOperands.push_back (fir::getBase (converter.genExprValue (
1840
+ *Fortran::semantics::GetExpr (value), stmtCtx)));
1841
+ }
1842
+ waitOperandsDeviceTypes.push_back (crtDeviceTypeAttr);
1843
+ waitOperandsSegments.push_back (waitOperands.size () - crtWaitOperands);
1844
+ } else {
1845
+ waitOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1846
+ }
1823
1847
} else if (const auto *numGangsClause =
1824
1848
std::get_if<Fortran::parser::AccClause::NumGangs>(
1825
1849
&clause.u )) {
1850
+ auto crtNumGangs = numGangs.size ();
1826
1851
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v )
1827
1852
numGangs.push_back (fir::getBase (converter.genExprValue (
1828
1853
*Fortran::semantics::GetExpr (expr), stmtCtx)));
1854
+ numGangsDeviceTypes.push_back (crtDeviceTypeAttr);
1855
+ numGangsSegments.push_back (numGangs.size () - crtNumGangs);
1829
1856
} else if (const auto *numWorkersClause =
1830
1857
std::get_if<Fortran::parser::AccClause::NumWorkers>(
1831
1858
&clause.u )) {
1832
- numWorkers = fir::getBase (converter.genExprValue (
1833
- *Fortran::semantics::GetExpr (numWorkersClause->v ), stmtCtx));
1859
+ numWorkers.push_back (fir::getBase (converter.genExprValue (
1860
+ *Fortran::semantics::GetExpr (numWorkersClause->v ), stmtCtx)));
1861
+ numWorkersDeviceTypes.push_back (crtDeviceTypeAttr);
1834
1862
} else if (const auto *vectorLengthClause =
1835
1863
std::get_if<Fortran::parser::AccClause::VectorLength>(
1836
1864
&clause.u )) {
1837
- vectorLength = fir::getBase (converter.genExprValue (
1838
- *Fortran::semantics::GetExpr (vectorLengthClause->v ), stmtCtx));
1865
+ vectorLength.push_back (fir::getBase (converter.genExprValue (
1866
+ *Fortran::semantics::GetExpr (vectorLengthClause->v ), stmtCtx)));
1867
+ vectorLengthDeviceTypes.push_back (crtDeviceTypeAttr);
1839
1868
} else if (const auto *ifClause =
1840
1869
std::get_if<Fortran::parser::AccClause::If>(&clause.u )) {
1841
1870
genIfClause (converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -1986,18 +2015,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
1986
2015
else if ((defaultClause->v ).v ==
1987
2016
llvm::acc::DefaultValue::ACC_Default_present)
1988
2017
hasDefaultPresent = true ;
2018
+ } else if (const auto *deviceTypeClause =
2019
+ std::get_if<Fortran::parser::AccClause::DeviceType>(
2020
+ &clause.u )) {
2021
+ const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
2022
+ deviceTypeClause->v ;
2023
+ assert (deviceTypeExprList.v .size () == 1 &&
2024
+ " expect only one device_type expr" );
2025
+ crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
2026
+ builder.getContext (), getDeviceType (deviceTypeExprList.v .front ().v ));
1989
2027
}
1990
2028
}
1991
2029
1992
2030
// Prepare the operand segment size attribute and the operands value range.
1993
2031
llvm::SmallVector<mlir::Value, 8 > operands;
1994
2032
llvm::SmallVector<int32_t , 8 > operandSegments;
1995
- addOperand (operands, operandSegments, async);
2033
+ addOperands (operands, operandSegments, async);
1996
2034
addOperands (operands, operandSegments, waitOperands);
1997
2035
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
1998
2036
addOperands (operands, operandSegments, numGangs);
1999
- addOperand (operands, operandSegments, numWorkers);
2000
- addOperand (operands, operandSegments, vectorLength);
2037
+ addOperands (operands, operandSegments, numWorkers);
2038
+ addOperands (operands, operandSegments, vectorLength);
2001
2039
}
2002
2040
addOperand (operands, operandSegments, ifCond);
2003
2041
addOperand (operands, operandSegments, selfCond);
@@ -2018,10 +2056,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2018
2056
builder, currentLocation, eval, operands, operandSegments,
2019
2057
outerCombined);
2020
2058
2021
- if (addAsyncAttr)
2022
- computeOp.setAsyncAttrAttr (builder.getUnitAttr ());
2023
- if (addWaitAttr)
2024
- computeOp.setWaitAttrAttr (builder.getUnitAttr ());
2025
2059
if (addSelfAttr)
2026
2060
computeOp.setSelfAttrAttr (builder.getUnitAttr ());
2027
2061
@@ -2030,6 +2064,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2030
2064
if (hasDefaultPresent)
2031
2065
computeOp.setDefaultAttr (mlir::acc::ClauseDefaultValue::Present);
2032
2066
2067
+ if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2068
+ if (!numWorkersDeviceTypes.empty ())
2069
+ computeOp.setNumWorkersDeviceTypeAttr (
2070
+ mlir::ArrayAttr::get (builder.getContext (), numWorkersDeviceTypes));
2071
+ if (!vectorLengthDeviceTypes.empty ())
2072
+ computeOp.setVectorLengthDeviceTypeAttr (
2073
+ mlir::ArrayAttr::get (builder.getContext (), vectorLengthDeviceTypes));
2074
+ if (!numGangsDeviceTypes.empty ())
2075
+ computeOp.setNumGangsDeviceTypeAttr (
2076
+ mlir::ArrayAttr::get (builder.getContext (), numGangsDeviceTypes));
2077
+ if (!numGangsSegments.empty ())
2078
+ computeOp.setNumGangsSegmentsAttr (
2079
+ builder.getDenseI32ArrayAttr (numGangsSegments));
2080
+ }
2081
+ if (!asyncDeviceTypes.empty ())
2082
+ computeOp.setAsyncDeviceTypeAttr (builder.getArrayAttr (asyncDeviceTypes));
2083
+ if (!asyncOnlyDeviceTypes.empty ())
2084
+ computeOp.setAsyncOnlyAttr (builder.getArrayAttr (asyncOnlyDeviceTypes));
2085
+
2086
+ if (!waitOperandsDeviceTypes.empty ())
2087
+ computeOp.setWaitOperandsDeviceTypeAttr (
2088
+ builder.getArrayAttr (waitOperandsDeviceTypes));
2089
+ if (!waitOperandsSegments.empty ())
2090
+ computeOp.setWaitOperandsSegmentsAttr (
2091
+ builder.getDenseI32ArrayAttr (waitOperandsSegments));
2092
+ if (!waitOnlyDeviceTypes.empty ())
2093
+ computeOp.setWaitOnlyAttr (builder.getArrayAttr (waitOnlyDeviceTypes));
2094
+
2033
2095
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2034
2096
if (!privatizations.empty ())
2035
2097
computeOp.setPrivatizationsAttr (
0 commit comments