@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
1480
1480
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
1481
1481
return mlir::acc::DeviceType::Multicore;
1482
1482
}
1483
- return mlir::acc::DeviceType::Default ;
1483
+ return mlir::acc::DeviceType::None ;
1484
1484
}
1485
1485
1486
1486
static void gatherDeviceTypeAttrs (
@@ -1781,61 +1781,89 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
1781
1781
bool outerCombined = false ) {
1782
1782
1783
1783
// Parallel operation operands
1784
- mlir::Value async;
1785
- mlir::Value numWorkers;
1786
- mlir::Value vectorLength;
1787
1784
mlir::Value ifCond;
1788
1785
mlir::Value selfCond;
1789
- mlir::Value waitDevnum;
1790
1786
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
1791
1787
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
1792
- dataClauseOperands, numGangs;
1788
+ dataClauseOperands, numGangs, numWorkers, vectorLength, async;
1789
+ llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
1790
+ vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
1791
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes;
1792
+ llvm::SmallVector<int32_t > numGangsSegments, waitOperandsSegments;
1793
1793
1794
1794
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
1795
1795
firstprivateOperands;
1796
1796
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
1797
1797
reductionRecipes;
1798
1798
1799
- // Async, wait and self clause have optional values but can be present with
1799
+ // Self clause has optional values but can be present with
1800
1800
// no value as well. When there is no value, the op has an attribute to
1801
1801
// represent the clause.
1802
- bool addAsyncAttr = false ;
1803
- bool addWaitAttr = false ;
1804
1802
bool addSelfAttr = false ;
1805
1803
1806
1804
bool hasDefaultNone = false ;
1807
1805
bool hasDefaultPresent = false ;
1808
1806
1809
1807
fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
1810
1808
1809
+ // device_type attribute is set to `none` until a device_type clause is
1810
+ // encountered.
1811
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
1812
+ builder.getContext (), mlir::acc::DeviceType::None);
1813
+
1811
1814
// Lower clauses values mapped to operands.
1812
1815
// Keep track of each group of operands separately as clauses can appear
1813
1816
// more than once.
1814
1817
for (const Fortran::parser::AccClause &clause : accClauseList.v ) {
1815
1818
mlir::Location clauseLocation = converter.genLocation (clause.source );
1816
1819
if (const auto *asyncClause =
1817
1820
std::get_if<Fortran::parser::AccClause::Async>(&clause.u )) {
1818
- genAsyncClause (converter, asyncClause, async, addAsyncAttr, stmtCtx);
1821
+ const auto &asyncClauseValue = asyncClause->v ;
1822
+ if (asyncClauseValue) { // async has a value.
1823
+ async.push_back (fir::getBase (converter.genExprValue (
1824
+ *Fortran::semantics::GetExpr (*asyncClauseValue), stmtCtx)));
1825
+ asyncDeviceTypes.push_back (crtDeviceTypeAttr);
1826
+ } else {
1827
+ asyncOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1828
+ }
1819
1829
} else if (const auto *waitClause =
1820
1830
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
1821
- genWaitClause (converter, waitClause, waitOperands, waitDevnum,
1822
- addWaitAttr, stmtCtx);
1831
+ const auto &waitClauseValue = waitClause->v ;
1832
+ if (waitClauseValue) { // wait has a value.
1833
+ const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1834
+ const auto &waitList =
1835
+ std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1836
+ auto crtWaitOperands = waitOperands.size ();
1837
+ for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1838
+ waitOperands.push_back (fir::getBase (converter.genExprValue (
1839
+ *Fortran::semantics::GetExpr (value), stmtCtx)));
1840
+ }
1841
+ waitOperandsDeviceTypes.push_back (crtDeviceTypeAttr);
1842
+ waitOperandsSegments.push_back (waitOperands.size () - crtWaitOperands);
1843
+ } else {
1844
+ waitOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1845
+ }
1823
1846
} else if (const auto *numGangsClause =
1824
1847
std::get_if<Fortran::parser::AccClause::NumGangs>(
1825
1848
&clause.u )) {
1849
+ auto crtNumGangs = numGangs.size ();
1826
1850
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v )
1827
1851
numGangs.push_back (fir::getBase (converter.genExprValue (
1828
1852
*Fortran::semantics::GetExpr (expr), stmtCtx)));
1853
+ numGangsDeviceTypes.push_back (crtDeviceTypeAttr);
1854
+ numGangsSegments.push_back (numGangs.size () - crtNumGangs);
1829
1855
} else if (const auto *numWorkersClause =
1830
1856
std::get_if<Fortran::parser::AccClause::NumWorkers>(
1831
1857
&clause.u )) {
1832
- numWorkers = fir::getBase (converter.genExprValue (
1833
- *Fortran::semantics::GetExpr (numWorkersClause->v ), stmtCtx));
1858
+ numWorkers.push_back (fir::getBase (converter.genExprValue (
1859
+ *Fortran::semantics::GetExpr (numWorkersClause->v ), stmtCtx)));
1860
+ numWorkersDeviceTypes.push_back (crtDeviceTypeAttr);
1834
1861
} else if (const auto *vectorLengthClause =
1835
1862
std::get_if<Fortran::parser::AccClause::VectorLength>(
1836
1863
&clause.u )) {
1837
- vectorLength = fir::getBase (converter.genExprValue (
1838
- *Fortran::semantics::GetExpr (vectorLengthClause->v ), stmtCtx));
1864
+ vectorLength.push_back (fir::getBase (converter.genExprValue (
1865
+ *Fortran::semantics::GetExpr (vectorLengthClause->v ), stmtCtx)));
1866
+ vectorLengthDeviceTypes.push_back (crtDeviceTypeAttr);
1839
1867
} else if (const auto *ifClause =
1840
1868
std::get_if<Fortran::parser::AccClause::If>(&clause.u )) {
1841
1869
genIfClause (converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -1986,18 +2014,27 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
1986
2014
else if ((defaultClause->v ).v ==
1987
2015
llvm::acc::DefaultValue::ACC_Default_present)
1988
2016
hasDefaultPresent = true ;
2017
+ } else if (const auto *deviceTypeClause =
2018
+ std::get_if<Fortran::parser::AccClause::DeviceType>(
2019
+ &clause.u )) {
2020
+ const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
2021
+ deviceTypeClause->v ;
2022
+ assert (deviceTypeExprList.v .size () == 1 &&
2023
+ " expect only one device_type expr" );
2024
+ crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
2025
+ builder.getContext (), getDeviceType (deviceTypeExprList.v .front ().v ));
1989
2026
}
1990
2027
}
1991
2028
1992
2029
// Prepare the operand segment size attribute and the operands value range.
1993
2030
llvm::SmallVector<mlir::Value, 8 > operands;
1994
2031
llvm::SmallVector<int32_t , 8 > operandSegments;
1995
- addOperand (operands, operandSegments, async);
2032
+ addOperands (operands, operandSegments, async);
1996
2033
addOperands (operands, operandSegments, waitOperands);
1997
2034
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
1998
2035
addOperands (operands, operandSegments, numGangs);
1999
- addOperand (operands, operandSegments, numWorkers);
2000
- addOperand (operands, operandSegments, vectorLength);
2036
+ addOperands (operands, operandSegments, numWorkers);
2037
+ addOperands (operands, operandSegments, vectorLength);
2001
2038
}
2002
2039
addOperand (operands, operandSegments, ifCond);
2003
2040
addOperand (operands, operandSegments, selfCond);
@@ -2018,10 +2055,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2018
2055
builder, currentLocation, eval, operands, operandSegments,
2019
2056
outerCombined);
2020
2057
2021
- if (addAsyncAttr)
2022
- computeOp.setAsyncAttrAttr (builder.getUnitAttr ());
2023
- if (addWaitAttr)
2024
- computeOp.setWaitAttrAttr (builder.getUnitAttr ());
2025
2058
if (addSelfAttr)
2026
2059
computeOp.setSelfAttrAttr (builder.getUnitAttr ());
2027
2060
@@ -2030,6 +2063,34 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2030
2063
if (hasDefaultPresent)
2031
2064
computeOp.setDefaultAttr (mlir::acc::ClauseDefaultValue::Present);
2032
2065
2066
+ if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2067
+ if (!numWorkersDeviceTypes.empty ())
2068
+ computeOp.setNumWorkersDeviceTypeAttr (
2069
+ mlir::ArrayAttr::get (builder.getContext (), numWorkersDeviceTypes));
2070
+ if (!vectorLengthDeviceTypes.empty ())
2071
+ computeOp.setVectorLengthDeviceTypeAttr (
2072
+ mlir::ArrayAttr::get (builder.getContext (), vectorLengthDeviceTypes));
2073
+ if (!numGangsDeviceTypes.empty ())
2074
+ computeOp.setNumGangsDeviceTypeAttr (
2075
+ mlir::ArrayAttr::get (builder.getContext (), numGangsDeviceTypes));
2076
+ if (!numGangsSegments.empty ())
2077
+ computeOp.setNumGangsSegmentsAttr (
2078
+ builder.getDenseI32ArrayAttr (numGangsSegments));
2079
+ }
2080
+ if (!asyncDeviceTypes.empty ())
2081
+ computeOp.setAsyncDeviceTypeAttr (builder.getArrayAttr (asyncDeviceTypes));
2082
+ if (!asyncOnlyDeviceTypes.empty ())
2083
+ computeOp.setAsyncOnlyAttr (builder.getArrayAttr (asyncOnlyDeviceTypes));
2084
+
2085
+ if (!waitOperandsDeviceTypes.empty ())
2086
+ computeOp.setWaitOperandsDeviceTypeAttr (
2087
+ builder.getArrayAttr (waitOperandsDeviceTypes));
2088
+ if (!waitOperandsSegments.empty ())
2089
+ computeOp.setWaitOperandsSegmentsAttr (
2090
+ builder.getDenseI32ArrayAttr (waitOperandsSegments));
2091
+ if (!waitOnlyDeviceTypes.empty ())
2092
+ computeOp.setWaitOnlyAttr (builder.getArrayAttr (waitOnlyDeviceTypes));
2093
+
2033
2094
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2034
2095
if (!privatizations.empty ())
2035
2096
computeOp.setPrivatizationsAttr (
0 commit comments