@@ -171,7 +171,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
171
171
builder, loc, registerFuncOp.getArgument (0 ), asFortranDesc, bounds,
172
172
/* structured=*/ false , /* implicit=*/ true ,
173
173
mlir::acc::DataClause::acc_update_device, descTy);
174
- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
174
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 1 };
175
175
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
176
176
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
177
177
@@ -245,7 +245,7 @@ static void createDeclareDeallocFuncWithArg(
245
245
builder, loc, loadOp, asFortran, bounds,
246
246
/* structured=*/ false , /* implicit=*/ true ,
247
247
mlir::acc::DataClause::acc_update_device, loadOp.getType ());
248
- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
248
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 1 };
249
249
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
250
250
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
251
251
modBuilder.setInsertionPointAfter (postDeallocOp);
@@ -1559,39 +1559,44 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
1559
1559
}
1560
1560
}
1561
1561
1562
- static void
1563
- genWaitClause ( Fortran::lower::AbstractConverter &converter,
1564
- const Fortran::parser::AccClause::Wait *waitClause,
1565
- llvm::SmallVector<mlir::Value> &waitOperands,
1566
- llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1567
- llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1568
- llvm::SmallVector<int32_t > &waitOperandsSegments ,
1569
- mlir::Value &waitDevnum ,
1570
- llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
1571
- Fortran::lower::StatementContext &stmtCtx) {
1562
+ static void genWaitClauseWithDeviceType (
1563
+ Fortran::lower::AbstractConverter &converter,
1564
+ const Fortran::parser::AccClause::Wait *waitClause,
1565
+ llvm::SmallVector<mlir::Value> &waitOperands,
1566
+ llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1567
+ llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1568
+ llvm::SmallVector<bool > &hasDevnums ,
1569
+ llvm::SmallVector< int32_t > &waitOperandsSegments ,
1570
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
1571
+ Fortran::lower::StatementContext &stmtCtx) {
1572
1572
const auto &waitClauseValue = waitClause->v ;
1573
1573
if (waitClauseValue) { // wait has a value.
1574
+ llvm::SmallVector<mlir::Value> waitValues;
1575
+
1574
1576
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1577
+ const auto &waitDevnumValue =
1578
+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1579
+ bool hasDevnum = false ;
1580
+ if (waitDevnumValue) {
1581
+ waitValues.push_back (fir::getBase (converter.genExprValue (
1582
+ *Fortran::semantics::GetExpr (*waitDevnumValue), stmtCtx)));
1583
+ hasDevnum = true ;
1584
+ }
1585
+
1575
1586
const auto &waitList =
1576
1587
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1577
- llvm::SmallVector<mlir::Value> waitValues;
1578
1588
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1579
1589
waitValues.push_back (fir::getBase (converter.genExprValue (
1580
1590
*Fortran::semantics::GetExpr (value), stmtCtx)));
1581
1591
}
1592
+
1582
1593
for (auto deviceTypeAttr : deviceTypeAttrs) {
1583
1594
for (auto value : waitValues)
1584
1595
waitOperands.push_back (value);
1585
1596
waitOperandsDeviceTypes.push_back (deviceTypeAttr);
1586
1597
waitOperandsSegments.push_back (waitValues.size ());
1598
+ hasDevnums.push_back (hasDevnum);
1587
1599
}
1588
-
1589
- // TODO: move to device_type model.
1590
- const auto &waitDevnumValue =
1591
- std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t );
1592
- if (waitDevnumValue)
1593
- waitDevnum = fir::getBase (converter.genExprValue (
1594
- *Fortran::semantics::GetExpr (*waitDevnumValue), stmtCtx));
1595
1600
} else {
1596
1601
for (auto deviceTypeAttr : deviceTypeAttrs)
1597
1602
waitOnlyDeviceTypes.push_back (deviceTypeAttr);
@@ -2093,12 +2098,12 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2093
2098
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
2094
2099
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2095
2100
llvm::SmallVector<int32_t > numGangsSegments, waitOperandsSegments;
2101
+ llvm::SmallVector<bool > hasWaitDevnums;
2096
2102
2097
2103
llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
2098
2104
firstprivateOperands;
2099
2105
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
2100
2106
reductionRecipes;
2101
- mlir::Value waitDevnum; // TODO not yet implemented on compute op.
2102
2107
2103
2108
// Self clause has optional values but can be present with
2104
2109
// no value as well. When there is no value, the op has an attribute to
@@ -2128,9 +2133,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2128
2133
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
2129
2134
} else if (const auto *waitClause =
2130
2135
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
2131
- genWaitClause (converter, waitClause, waitOperands,
2132
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2133
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
2136
+ genWaitClauseWithDeviceType (converter, waitClause, waitOperands,
2137
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2138
+ hasWaitDevnums, waitOperandsSegments,
2139
+ crtDeviceTypes, stmtCtx);
2134
2140
} else if (const auto *numGangsClause =
2135
2141
std::get_if<Fortran::parser::AccClause::NumGangs>(
2136
2142
&clause.u )) {
@@ -2372,7 +2378,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2372
2378
builder.getDenseI32ArrayAttr (numGangsSegments));
2373
2379
}
2374
2380
if (!asyncDeviceTypes.empty ())
2375
- computeOp.setAsyncDeviceTypeAttr (builder.getArrayAttr (asyncDeviceTypes));
2381
+ computeOp.setAsyncOperandsDeviceTypeAttr (
2382
+ builder.getArrayAttr (asyncDeviceTypes));
2376
2383
if (!asyncOnlyDeviceTypes.empty ())
2377
2384
computeOp.setAsyncOnlyAttr (builder.getArrayAttr (asyncOnlyDeviceTypes));
2378
2385
@@ -2382,6 +2389,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
2382
2389
if (!waitOperandsSegments.empty ())
2383
2390
computeOp.setWaitOperandsSegmentsAttr (
2384
2391
builder.getDenseI32ArrayAttr (waitOperandsSegments));
2392
+ if (!hasWaitDevnums.empty ())
2393
+ computeOp.setHasWaitDevnumAttr (builder.getBoolArrayAttr (hasWaitDevnums));
2385
2394
if (!waitOnlyDeviceTypes.empty ())
2386
2395
computeOp.setWaitOnlyAttr (builder.getArrayAttr (waitOnlyDeviceTypes));
2387
2396
@@ -2427,6 +2436,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2427
2436
llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
2428
2437
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2429
2438
llvm::SmallVector<int32_t > waitOperandsSegments;
2439
+ llvm::SmallVector<bool > hasWaitDevnums;
2430
2440
2431
2441
bool hasDefaultNone = false ;
2432
2442
bool hasDefaultPresent = false ;
@@ -2523,9 +2533,10 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2523
2533
asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
2524
2534
} else if (const auto *waitClause =
2525
2535
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
2526
- genWaitClause (converter, waitClause, waitOperands,
2527
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2528
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
2536
+ genWaitClauseWithDeviceType (converter, waitClause, waitOperands,
2537
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2538
+ hasWaitDevnums, waitOperandsSegments,
2539
+ crtDeviceTypes, stmtCtx);
2529
2540
} else if (const auto *defaultClause =
2530
2541
std::get_if<Fortran::parser::AccClause::Default>(&clause.u )) {
2531
2542
if ((defaultClause->v ).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2545,7 +2556,6 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2545
2556
llvm::SmallVector<int32_t > operandSegments;
2546
2557
addOperand (operands, operandSegments, ifCond);
2547
2558
addOperands (operands, operandSegments, async);
2548
- addOperand (operands, operandSegments, waitDevnum);
2549
2559
addOperands (operands, operandSegments, waitOperands);
2550
2560
addOperands (operands, operandSegments, dataClauseOperands);
2551
2561
@@ -2557,7 +2567,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2557
2567
operandSegments);
2558
2568
2559
2569
if (!asyncDeviceTypes.empty ())
2560
- dataOp.setAsyncDeviceTypeAttr (builder.getArrayAttr (asyncDeviceTypes));
2570
+ dataOp.setAsyncOperandsDeviceTypeAttr (
2571
+ builder.getArrayAttr (asyncDeviceTypes));
2561
2572
if (!asyncOnlyDeviceTypes.empty ())
2562
2573
dataOp.setAsyncOnlyAttr (builder.getArrayAttr (asyncOnlyDeviceTypes));
2563
2574
if (!waitOperandsDeviceTypes.empty ())
@@ -2566,6 +2577,8 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
2566
2577
if (!waitOperandsSegments.empty ())
2567
2578
dataOp.setWaitOperandsSegmentsAttr (
2568
2579
builder.getDenseI32ArrayAttr (waitOperandsSegments));
2580
+ if (!hasWaitDevnums.empty ())
2581
+ dataOp.setHasWaitDevnumAttr (builder.getBoolArrayAttr (hasWaitDevnums));
2569
2582
if (!waitOnlyDeviceTypes.empty ())
2570
2583
dataOp.setWaitOnlyAttr (builder.getArrayAttr (waitOnlyDeviceTypes));
2571
2584
@@ -3007,6 +3020,11 @@ getArrayAttr(fir::FirOpBuilder &b,
3007
3020
return attributes.empty () ? nullptr : b.getArrayAttr (attributes);
3008
3021
}
3009
3022
3023
+ static inline mlir::ArrayAttr
3024
+ getBoolArrayAttr (fir::FirOpBuilder &b, llvm::SmallVector<bool > &values) {
3025
+ return values.empty () ? nullptr : b.getBoolArrayAttr (values);
3026
+ }
3027
+
3010
3028
static inline mlir::DenseI32ArrayAttr
3011
3029
getDenseI32ArrayAttr (fir::FirOpBuilder &builder,
3012
3030
llvm::SmallVector<int32_t > &values) {
@@ -3024,6 +3042,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
3024
3042
waitOperands, deviceTypeOperands, asyncOperands;
3025
3043
llvm::SmallVector<mlir::Attribute> asyncOperandsDeviceTypes,
3026
3044
asyncOnlyDeviceTypes, waitOperandsDeviceTypes, waitOnlyDeviceTypes;
3045
+ llvm::SmallVector<bool > hasWaitDevnums;
3027
3046
llvm::SmallVector<int32_t > waitOperandsSegments;
3028
3047
3029
3048
fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
@@ -3051,9 +3070,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
3051
3070
crtDeviceTypes, stmtCtx);
3052
3071
} else if (const auto *waitClause =
3053
3072
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
3054
- genWaitClause (converter, waitClause, waitOperands,
3055
- waitOperandsDeviceTypes, waitOnlyDeviceTypes,
3056
- waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
3073
+ genWaitClauseWithDeviceType (converter, waitClause, waitOperands,
3074
+ waitOperandsDeviceTypes, waitOnlyDeviceTypes,
3075
+ hasWaitDevnums, waitOperandsSegments,
3076
+ crtDeviceTypes, stmtCtx);
3057
3077
} else if (const auto *deviceTypeClause =
3058
3078
std::get_if<Fortran::parser::AccClause::DeviceType>(
3059
3079
&clause.u )) {
@@ -3092,9 +3112,10 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
3092
3112
builder.create <mlir::acc::UpdateOp>(
3093
3113
currentLocation, ifCond, asyncOperands,
3094
3114
getArrayAttr (builder, asyncOperandsDeviceTypes),
3095
- getArrayAttr (builder, asyncOnlyDeviceTypes), waitDevnum, waitOperands,
3115
+ getArrayAttr (builder, asyncOnlyDeviceTypes), waitOperands,
3096
3116
getDenseI32ArrayAttr (builder, waitOperandsSegments),
3097
3117
getArrayAttr (builder, waitOperandsDeviceTypes),
3118
+ getBoolArrayAttr (builder, hasWaitDevnums),
3098
3119
getArrayAttr (builder, waitOnlyDeviceTypes), dataClauseOperands,
3099
3120
ifPresent);
3100
3121
@@ -3268,7 +3289,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
3268
3289
builder, loc, addrOp, asFortranDesc, bounds,
3269
3290
/* structured=*/ false , /* implicit=*/ true ,
3270
3291
mlir::acc::DataClause::acc_update_device, addrOp.getType ());
3271
- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
3292
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 1 };
3272
3293
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
3273
3294
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
3274
3295
@@ -3349,7 +3370,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
3349
3370
builder, loc, addrOp, asFortran, bounds,
3350
3371
/* structured=*/ false , /* implicit=*/ true ,
3351
3372
mlir::acc::DataClause::acc_update_device, addrOp.getType ());
3352
- llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 0 , 1 };
3373
+ llvm::SmallVector<int32_t > operandSegments{0 , 0 , 0 , 1 };
3353
3374
llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult ()};
3354
3375
createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
3355
3376
modBuilder.setInsertionPointAfter (postDeallocOp);
0 commit comments