@@ -1464,6 +1464,24 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
1464
1464
}
1465
1465
}
1466
1466
1467
+ static void
1468
+ genAsyncClause (Fortran::lower::AbstractConverter &converter,
1469
+ const Fortran::parser::AccClause::Async *asyncClause,
1470
+ llvm::SmallVector<mlir::Value> &async,
1471
+ llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
1472
+ llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
1473
+ mlir::acc::DeviceTypeAttr deviceTypeAttr,
1474
+ Fortran::lower::StatementContext &stmtCtx) {
1475
+ const auto &asyncClauseValue = asyncClause->v ;
1476
+ if (asyncClauseValue) { // async has a value.
1477
+ async.push_back (fir::getBase (converter.genExprValue (
1478
+ *Fortran::semantics::GetExpr (*asyncClauseValue), stmtCtx)));
1479
+ asyncDeviceTypes.push_back (deviceTypeAttr);
1480
+ } else {
1481
+ asyncOnlyDeviceTypes.push_back (deviceTypeAttr);
1482
+ }
1483
+ }
1484
+
1467
1485
static mlir::acc::DeviceType
1468
1486
getDeviceType (Fortran::parser::AccDeviceTypeExpr::Device device) {
1469
1487
switch (device) {
@@ -1533,6 +1551,39 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
1533
1551
}
1534
1552
}
1535
1553
1554
+ static void
1555
+ genWaitClause (Fortran::lower::AbstractConverter &converter,
1556
+ const Fortran::parser::AccClause::Wait *waitClause,
1557
+ llvm::SmallVector<mlir::Value> &waitOperands,
1558
+ llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1559
+ llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1560
+ llvm::SmallVector<int32_t > &waitOperandsSegments,
1561
+ mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
1562
+ Fortran::lower::StatementContext &stmtCtx) {
1563
+ const auto &waitClauseValue = waitClause->v ;
1564
+ if (waitClauseValue) { // wait has a value.
1565
+ const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1566
+ const auto &waitList =
1567
+ std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1568
+ auto crtWaitOperands = waitOperands.size ();
1569
+ for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1570
+ waitOperands.push_back (fir::getBase (converter.genExprValue (
1571
+ *Fortran::semantics::GetExpr (value), stmtCtx)));
1572
+ }
1573
+ waitOperandsDeviceTypes.push_back (deviceTypeAttr);
1574
+ waitOperandsSegments.push_back (waitOperands.size () - crtWaitOperands);
1575
+
1576
+ // TODO: move to device_type model.
1577
+ const auto &waitDevnumValue =
1578
+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1579
+ if (waitDevnumValue)
1580
+ waitDevnum = fir::getBase (converter.genExprValue (
1581
+ *Fortran::semantics::GetExpr (*waitDevnumValue), stmtCtx));
1582
+ } else {
1583
+ waitOnlyDeviceTypes.push_back (deviceTypeAttr);
1584
+ }
1585
+ }
1586
+
1536
1587
static mlir::acc::LoopOp
1537
1588
createLoopOp (Fortran::lower::AbstractConverter &converter,
1538
1589
mlir::Location currentLocation,
@@ -1795,6 +1846,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
1795
1846
firstprivateOperands;
1796
1847
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
1797
1848
reductionRecipes;
1849
+ mlir::Value waitDevnum; // TODO not yet implemented on compute op.
1798
1850
1799
1851
// Self clause has optional values but can be present with
1800
1852
// no value as well. When there is no value, the op has an attribute to
@@ -1818,31 +1870,14 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
1818
1870
mlir::Location clauseLocation = converter.genLocation (clause.source );
1819
1871
if (const auto *asyncClause =
1820
1872
std::get_if<Fortran::parser::AccClause::Async>(&clause.u )) {
1821
- const auto &asyncClauseValue = asyncClause->v ;
1822
- if (asyncClauseValue) { // async has a value.
1823
- async.push_back (fir::getBase (converter.genExprValue (
1824
- *Fortran::semantics::GetExpr (*asyncClauseValue), stmtCtx)));
1825
- asyncDeviceTypes.push_back (crtDeviceTypeAttr);
1826
- } else {
1827
- asyncOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1828
- }
1873
+ genAsyncClause (converter, asyncClause, async, asyncDeviceTypes,
1874
+ asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
1829
1875
} else if (const auto *waitClause =
1830
1876
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
1831
- const auto &waitClauseValue = waitClause->v ;
1832
- if (waitClauseValue) { // wait has a value.
1833
- const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1834
- const auto &waitList =
1835
- std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1836
- auto crtWaitOperands = waitOperands.size ();
1837
- for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1838
- waitOperands.push_back (fir::getBase (converter.genExprValue (
1839
- *Fortran::semantics::GetExpr (value), stmtCtx)));
1840
- }
1841
- waitOperandsDeviceTypes.push_back (crtDeviceTypeAttr);
1842
- waitOperandsSegments.push_back (waitOperands.size () - crtWaitOperands);
1843
- } else {
1844
- waitOnlyDeviceTypes.push_back (crtDeviceTypeAttr);
1845
- }
1877
+ genWaitClause (converter, waitClause, waitOperands,
1878
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
1879
+ waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
1880
+ stmtCtx);
1846
1881
} else if (const auto *numGangsClause =
1847
1882
std::get_if<Fortran::parser::AccClause::NumGangs>(
1848
1883
&clause.u )) {
@@ -2126,21 +2161,24 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2126
2161
Fortran::semantics::SemanticsContext &semanticsContext,
2127
2162
Fortran::lower::StatementContext &stmtCtx,
2128
2163
const Fortran::parser::AccClauseList &accClauseList) {
2129
- mlir::Value ifCond, async, waitDevnum;
2164
+ mlir::Value ifCond, waitDevnum;
2130
2165
llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
2131
- copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands;
2132
-
2133
- // Async and wait have an optional value but can be present with
2134
- // no value as well. When there is no value, the op has an attribute to
2135
- // represent the clause.
2136
- bool addAsyncAttr = false ;
2137
- bool addWaitAttr = false ;
2166
+ copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands,
2167
+ async;
2168
+ llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
2169
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2170
+ llvm::SmallVector<int32_t > waitOperandsSegments;
2138
2171
2139
2172
bool hasDefaultNone = false ;
2140
2173
bool hasDefaultPresent = false ;
2141
2174
2142
2175
fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
2143
2176
2177
+ // device_type attribute is set to `none` until a device_type clause is
2178
+ // encountered.
2179
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
2180
+ builder.getContext (), mlir::acc::DeviceType::None);
2181
+
2144
2182
// Lower clauses values mapped to operands.
2145
2183
// Keep track of each group of operands separately as clauses can appear
2146
2184
// more than once.
@@ -2221,11 +2259,14 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2221
2259
dataClauseOperands.end ());
2222
2260
} else if (const auto *asyncClause =
2223
2261
std::get_if<Fortran::parser::AccClause::Async>(&clause.u )) {
2224
- genAsyncClause (converter, asyncClause, async, addAsyncAttr, stmtCtx);
2262
+ genAsyncClause (converter, asyncClause, async, asyncDeviceTypes,
2263
+ asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
2225
2264
} else if (const auto *waitClause =
2226
2265
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
2227
- genWaitClause (converter, waitClause, waitOperands, waitDevnum,
2228
- addWaitAttr, stmtCtx);
2266
+ genWaitClause (converter, waitClause, waitOperands,
2267
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2268
+ waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
2269
+ stmtCtx);
2229
2270
} else if (const auto *defaultClause =
2230
2271
std::get_if<Fortran::parser::AccClause::Default>(&clause.u )) {
2231
2272
if ((defaultClause->v ).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2239,7 +2280,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2239
2280
llvm::SmallVector<mlir::Value> operands;
2240
2281
llvm::SmallVector<int32_t > operandSegments;
2241
2282
addOperand (operands, operandSegments, ifCond);
2242
- addOperand (operands, operandSegments, async);
2283
+ addOperands (operands, operandSegments, async);
2243
2284
addOperand (operands, operandSegments, waitDevnum);
2244
2285
addOperands (operands, operandSegments, waitOperands);
2245
2286
addOperands (operands, operandSegments, dataClauseOperands);
@@ -2250,8 +2291,18 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2250
2291
auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
2251
2292
builder, currentLocation, eval, operands, operandSegments);
2252
2293
2253
- dataOp.setAsyncAttr (addAsyncAttr);
2254
- dataOp.setWaitAttr (addWaitAttr);
2294
+ if (!asyncDeviceTypes.empty ())
2295
+ dataOp.setAsyncDeviceTypeAttr (builder.getArrayAttr (asyncDeviceTypes));
2296
+ if (!asyncOnlyDeviceTypes.empty ())
2297
+ dataOp.setAsyncOnlyAttr (builder.getArrayAttr (asyncOnlyDeviceTypes));
2298
+ if (!waitOperandsDeviceTypes.empty ())
2299
+ dataOp.setWaitOperandsDeviceTypeAttr (
2300
+ builder.getArrayAttr (waitOperandsDeviceTypes));
2301
+ if (!waitOperandsSegments.empty ())
2302
+ dataOp.setWaitOperandsSegmentsAttr (
2303
+ builder.getDenseI32ArrayAttr (waitOperandsSegments));
2304
+ if (!waitOnlyDeviceTypes.empty ())
2305
+ dataOp.setWaitOnlyAttr (builder.getArrayAttr (waitOnlyDeviceTypes));
2255
2306
2256
2307
if (hasDefaultNone)
2257
2308
dataOp.setDefaultAttr (mlir::acc::ClauseDefaultValue::None);
0 commit comments