@@ -1480,7 +1480,7 @@ getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
1480
1480
case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
1481
1481
return mlir::acc::DeviceType::Multicore;
1482
1482
}
1483
- return mlir::acc::DeviceType::None ;
1483
+ return mlir::acc::DeviceType::Default ;
1484
1484
}
1485
1485
1486
1486
static void gatherDeviceTypeAttrs (
@@ -1781,90 +1781,61 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
1781
1781
bool outerCombined = false ) {
1782
1782
1783
1783
// Parallel operation operands
1784
+ mlir::Value async;
1785
+ mlir::Value numWorkers;
1786
+ mlir::Value vectorLength;
1784
1787
mlir::Value ifCond;
1785
1788
mlir::Value selfCond;
1786
1789
mlir::Value waitDevnum;
1787
1790
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
1788
1791
copyEntryOperands, copyoutEntryOperands, createEntryOperands,
1789
- dataClauseOperands, numGangs, numWorkers, vectorLength, async;
1790
- llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
1791
- vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
1792
- waitOperandsDeviceTypes, waitOnlyDeviceTypes;
1793
- llvm::SmallVector<int32_t > numGangsSegments, waitOperandsSegments;
1792
+ dataClauseOperands, numGangs;
1794
1793
1795
1794
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
1796
1795
firstprivateOperands;
1797
1796
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
1798
1797
reductionRecipes;
1799
1798
1800
- // Self clause has optional values but can be present with
1799
+ // Async, wait and self clause have optional values but can be present with
1801
1800
// no value as well. When there is no value, the op has an attribute to
1802
1801
// represent the clause.
1802
+ bool addAsyncAttr = false ;
1803
+ bool addWaitAttr = false ;
1803
1804
bool addSelfAttr = false ;
1804
1805
1805
1806
bool hasDefaultNone = false ;
1806
1807
bool hasDefaultPresent = false ;
1807
1808
1808
1809
fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
1809
1810
1810
- // device_type attribute is set to `none` until a device_type clause is
1811
- // encountered.
1812
- auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
1813
- builder.getContext (), mlir::acc::DeviceType::None);
1814
-
1815
1811
// Lower clauses values mapped to operands.
1816
1812
// Keep track of each group of operands separatly as clauses can appear
1817
1813
// more than once.
1818
1814
for (const Fortran::parser::AccClause &clause : accClauseList.v ) {
1819
1815
mlir::Location clauseLocation = converter.genLocation (clause.source );
1820
1816
if (const auto *asyncClause =
1821
1817
std::get_if<Fortran::parser::AccClause::Async>(&clause.u )) {
1822
- const auto &asyncClauseValue = asyncClause->v ;
1823
- if (asyncClauseValue) { // async has a value.
1824
- async.push_back (fir::getBase (converter.genExprValue (
1825
- *Fortran::semantics::GetExpr (*asyncClauseValue), stmtCtx)));
1826
- asyncDeviceTypes.push_back (crtDeviceTypeAttr);
1827
- } else {
1828
- asyncOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1829
- }
1818
+ genAsyncClause (converter, asyncClause, async, addAsyncAttr, stmtCtx);
1830
1819
} else if (const auto *waitClause =
1831
1820
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
1832
- const auto &waitClauseValue = waitClause->v ;
1833
- if (waitClauseValue) { // wait has a value.
1834
- const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1835
- const auto &waitList =
1836
- std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1837
- auto crtWaitOperands = waitOperands.size ();
1838
- for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1839
- waitOperands.push_back (fir::getBase (converter.genExprValue (
1840
- *Fortran::semantics::GetExpr (value), stmtCtx)));
1841
- }
1842
- waitOperandsDeviceTypes.push_back (crtDeviceTypeAttr);
1843
- waitOperandsSegments.push_back (waitOperands.size () - crtWaitOperands);
1844
- } else {
1845
- waitOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1846
- }
1821
+ genWaitClause (converter, waitClause, waitOperands, waitDevnum,
1822
+ addWaitAttr, stmtCtx);
1847
1823
} else if (const auto *numGangsClause =
1848
1824
std::get_if<Fortran::parser::AccClause::NumGangs>(
1849
1825
&clause.u )) {
1850
- auto crtNumGangs = numGangs.size ();
1851
1826
for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v )
1852
1827
numGangs.push_back (fir::getBase (converter.genExprValue (
1853
1828
*Fortran::semantics::GetExpr (expr), stmtCtx)));
1854
- numGangsDeviceTypes.push_back (crtDeviceTypeAttr);
1855
- numGangsSegments.push_back (numGangs.size () - crtNumGangs);
1856
1829
} else if (const auto *numWorkersClause =
1857
1830
std::get_if<Fortran::parser::AccClause::NumWorkers>(
1858
1831
&clause.u )) {
1859
- numWorkers.push_back (fir::getBase (converter.genExprValue (
1860
- *Fortran::semantics::GetExpr (numWorkersClause->v ), stmtCtx)));
1861
- numWorkersDeviceTypes.push_back (crtDeviceTypeAttr);
1832
+ numWorkers = fir::getBase (converter.genExprValue (
1833
+ *Fortran::semantics::GetExpr (numWorkersClause->v ), stmtCtx));
1862
1834
} else if (const auto *vectorLengthClause =
1863
1835
std::get_if<Fortran::parser::AccClause::VectorLength>(
1864
1836
&clause.u )) {
1865
- vectorLength.push_back (fir::getBase (converter.genExprValue (
1866
- *Fortran::semantics::GetExpr (vectorLengthClause->v ), stmtCtx)));
1867
- vectorLengthDeviceTypes.push_back (crtDeviceTypeAttr);
1837
+ vectorLength = fir::getBase (converter.genExprValue (
1838
+ *Fortran::semantics::GetExpr (vectorLengthClause->v ), stmtCtx));
1868
1839
} else if (const auto *ifClause =
1869
1840
std::get_if<Fortran::parser::AccClause::If>(&clause.u )) {
1870
1841
genIfClause (converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -2015,27 +1986,18 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2015
1986
else if ((defaultClause->v ).v ==
2016
1987
llvm::acc::DefaultValue::ACC_Default_present)
2017
1988
hasDefaultPresent = true ;
2018
- } else if (const auto *deviceTypeClause =
2019
- std::get_if<Fortran::parser::AccClause::DeviceType>(
2020
- &clause.u )) {
2021
- const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
2022
- deviceTypeClause->v ;
2023
- assert (deviceTypeExprList.v .size () == 1 &&
2024
- " expect only one device_type expr" );
2025
- crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
2026
- builder.getContext (), getDeviceType (deviceTypeExprList.v .front ().v ));
2027
1989
}
2028
1990
}
2029
1991
2030
1992
// Prepare the operand segment size attribute and the operands value range.
2031
1993
llvm::SmallVector<mlir::Value, 8 > operands;
2032
1994
llvm::SmallVector<int32_t , 8 > operandSegments;
2033
- addOperands (operands, operandSegments, async);
1995
+ addOperand (operands, operandSegments, async);
2034
1996
addOperands (operands, operandSegments, waitOperands);
2035
1997
if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2036
1998
addOperands (operands, operandSegments, numGangs);
2037
- addOperands (operands, operandSegments, numWorkers);
2038
- addOperands (operands, operandSegments, vectorLength);
1999
+ addOperand (operands, operandSegments, numWorkers);
2000
+ addOperand (operands, operandSegments, vectorLength);
2039
2001
}
2040
2002
addOperand (operands, operandSegments, ifCond);
2041
2003
addOperand (operands, operandSegments, selfCond);
@@ -2056,6 +2018,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2056
2018
builder, currentLocation, eval, operands, operandSegments,
2057
2019
outerCombined);
2058
2020
2021
+ if (addAsyncAttr)
2022
+ computeOp.setAsyncAttrAttr (builder.getUnitAttr ());
2023
+ if (addWaitAttr)
2024
+ computeOp.setWaitAttrAttr (builder.getUnitAttr ());
2059
2025
if (addSelfAttr)
2060
2026
computeOp.setSelfAttrAttr (builder.getUnitAttr ());
2061
2027
@@ -2064,34 +2030,6 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2064
2030
if (hasDefaultPresent)
2065
2031
computeOp.setDefaultAttr (mlir::acc::ClauseDefaultValue::Present);
2066
2032
2067
- if constexpr (!std::is_same_v<Op, mlir::acc::SerialOp>) {
2068
- if (!numWorkersDeviceTypes.empty ())
2069
- computeOp.setNumWorkersDeviceTypeAttr (
2070
- mlir::ArrayAttr::get (builder.getContext (), numWorkersDeviceTypes));
2071
- if (!vectorLengthDeviceTypes.empty ())
2072
- computeOp.setVectorLengthDeviceTypeAttr (
2073
- mlir::ArrayAttr::get (builder.getContext (), vectorLengthDeviceTypes));
2074
- if (!numGangsDeviceTypes.empty ())
2075
- computeOp.setNumGangsDeviceTypeAttr (
2076
- mlir::ArrayAttr::get (builder.getContext (), numGangsDeviceTypes));
2077
- if (!numGangsSegments.empty ())
2078
- computeOp.setNumGangsSegmentsAttr (
2079
- builder.getDenseI32ArrayAttr (numGangsSegments));
2080
- }
2081
- if (!asyncDeviceTypes.empty ())
2082
- computeOp.setAsyncDeviceTypeAttr (builder.getArrayAttr (asyncDeviceTypes));
2083
- if (!asyncOnlyDeviceTypes.empty ())
2084
- computeOp.setAsyncOnlyAttr (builder.getArrayAttr (asyncOnlyDeviceTypes));
2085
-
2086
- if (!waitOperandsDeviceTypes.empty ())
2087
- computeOp.setWaitOperandsDeviceTypeAttr (
2088
- builder.getArrayAttr (waitOperandsDeviceTypes));
2089
- if (!waitOperandsSegments.empty ())
2090
- computeOp.setWaitOperandsSegmentsAttr (
2091
- builder.getDenseI32ArrayAttr (waitOperandsSegments));
2092
- if (!waitOnlyDeviceTypes.empty ())
2093
- computeOp.setWaitOnlyAttr (builder.getArrayAttr (waitOnlyDeviceTypes));
2094
-
2095
2033
if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) {
2096
2034
if (!privatizations.empty ())
2097
2035
computeOp.setPrivatizationsAttr (
0 commit comments