Skip to content

Commit 71ec301

Browse files
authored
[mlir][openacc] Add device_type support for data operation (#76126)
Following #75864, this patch adds device_type support to the data operation on the async and wait operands and attributes.
1 parent 3096353 commit 71ec301

File tree

5 files changed

+176
-55
lines changed

5 files changed

+176
-55
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 88 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,6 +1464,24 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
14641464
}
14651465
}
14661466

1467+
static void
1468+
genAsyncClause(Fortran::lower::AbstractConverter &converter,
1469+
const Fortran::parser::AccClause::Async *asyncClause,
1470+
llvm::SmallVector<mlir::Value> &async,
1471+
llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
1472+
llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
1473+
mlir::acc::DeviceTypeAttr deviceTypeAttr,
1474+
Fortran::lower::StatementContext &stmtCtx) {
1475+
const auto &asyncClauseValue = asyncClause->v;
1476+
if (asyncClauseValue) { // async has a value.
1477+
async.push_back(fir::getBase(converter.genExprValue(
1478+
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
1479+
asyncDeviceTypes.push_back(deviceTypeAttr);
1480+
} else {
1481+
asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
1482+
}
1483+
}
1484+
14671485
static mlir::acc::DeviceType
14681486
getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
14691487
switch (device) {
@@ -1533,6 +1551,39 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
15331551
}
15341552
}
15351553

1554+
static void
1555+
genWaitClause(Fortran::lower::AbstractConverter &converter,
1556+
const Fortran::parser::AccClause::Wait *waitClause,
1557+
llvm::SmallVector<mlir::Value> &waitOperands,
1558+
llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
1559+
llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
1560+
llvm::SmallVector<int32_t> &waitOperandsSegments,
1561+
mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
1562+
Fortran::lower::StatementContext &stmtCtx) {
1563+
const auto &waitClauseValue = waitClause->v;
1564+
if (waitClauseValue) { // wait has a value.
1565+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1566+
const auto &waitList =
1567+
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1568+
auto crtWaitOperands = waitOperands.size();
1569+
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1570+
waitOperands.push_back(fir::getBase(converter.genExprValue(
1571+
*Fortran::semantics::GetExpr(value), stmtCtx)));
1572+
}
1573+
waitOperandsDeviceTypes.push_back(deviceTypeAttr);
1574+
waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
1575+
1576+
// TODO: move to device_type model.
1577+
const auto &waitDevnumValue =
1578+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1579+
if (waitDevnumValue)
1580+
waitDevnum = fir::getBase(converter.genExprValue(
1581+
*Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
1582+
} else {
1583+
waitOnlyDeviceTypes.push_back(deviceTypeAttr);
1584+
}
1585+
}
1586+
15361587
static mlir::acc::LoopOp
15371588
createLoopOp(Fortran::lower::AbstractConverter &converter,
15381589
mlir::Location currentLocation,
@@ -1795,6 +1846,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
17951846
firstprivateOperands;
17961847
llvm::SmallVector<mlir::Attribute> privatizations, firstPrivatizations,
17971848
reductionRecipes;
1849+
mlir::Value waitDevnum; // TODO not yet implemented on compute op.
17981850

17991851
// Self clause has optional values but can be present with
18001852
// no value as well. When there is no value, the op has an attribute to
@@ -1818,31 +1870,14 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
18181870
mlir::Location clauseLocation = converter.genLocation(clause.source);
18191871
if (const auto *asyncClause =
18201872
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
1821-
const auto &asyncClauseValue = asyncClause->v;
1822-
if (asyncClauseValue) { // async has a value.
1823-
async.push_back(fir::getBase(converter.genExprValue(
1824-
*Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
1825-
asyncDeviceTypes.push_back(crtDeviceTypeAttr);
1826-
} else {
1827-
asyncOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1828-
}
1873+
genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
1874+
asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
18291875
} else if (const auto *waitClause =
18301876
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
1831-
const auto &waitClauseValue = waitClause->v;
1832-
if (waitClauseValue) { // wait has a value.
1833-
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
1834-
const auto &waitList =
1835-
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
1836-
auto crtWaitOperands = waitOperands.size();
1837-
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
1838-
waitOperands.push_back(fir::getBase(converter.genExprValue(
1839-
*Fortran::semantics::GetExpr(value), stmtCtx)));
1840-
}
1841-
waitOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1842-
waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
1843-
} else {
1844-
waitOnlyDeviceTypes.push_back(crtDeviceTypeAttr);
1845-
}
1877+
genWaitClause(converter, waitClause, waitOperands,
1878+
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
1879+
waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
1880+
stmtCtx);
18461881
} else if (const auto *numGangsClause =
18471882
std::get_if<Fortran::parser::AccClause::NumGangs>(
18481883
&clause.u)) {
@@ -2126,21 +2161,24 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
21262161
Fortran::semantics::SemanticsContext &semanticsContext,
21272162
Fortran::lower::StatementContext &stmtCtx,
21282163
const Fortran::parser::AccClauseList &accClauseList) {
2129-
mlir::Value ifCond, async, waitDevnum;
2164+
mlir::Value ifCond, waitDevnum;
21302165
llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
2131-
copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands;
2132-
2133-
// Async and wait have an optional value but can be present with
2134-
// no value as well. When there is no value, the op has an attribute to
2135-
// represent the clause.
2136-
bool addAsyncAttr = false;
2137-
bool addWaitAttr = false;
2166+
copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands,
2167+
async;
2168+
llvm::SmallVector<mlir::Attribute> asyncDeviceTypes, asyncOnlyDeviceTypes,
2169+
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
2170+
llvm::SmallVector<int32_t> waitOperandsSegments;
21382171

21392172
bool hasDefaultNone = false;
21402173
bool hasDefaultPresent = false;
21412174

21422175
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
21432176

2177+
// device_type attribute is set to `none` until a device_type clause is
2178+
// encountered.
2179+
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
2180+
builder.getContext(), mlir::acc::DeviceType::None);
2181+
21442182
// Lower clauses values mapped to operands.
21452183
// Keep track of each group of operands separately as clauses can appear
21462184
// more than once.
@@ -2221,11 +2259,14 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
22212259
dataClauseOperands.end());
22222260
} else if (const auto *asyncClause =
22232261
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
2224-
genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx);
2262+
genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
2263+
asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
22252264
} else if (const auto *waitClause =
22262265
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
2227-
genWaitClause(converter, waitClause, waitOperands, waitDevnum,
2228-
addWaitAttr, stmtCtx);
2266+
genWaitClause(converter, waitClause, waitOperands,
2267+
waitOperandsDeviceTypes, waitOnlyDeviceTypes,
2268+
waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
2269+
stmtCtx);
22292270
} else if(const auto *defaultClause =
22302271
std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
22312272
if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
@@ -2239,7 +2280,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
22392280
llvm::SmallVector<mlir::Value> operands;
22402281
llvm::SmallVector<int32_t> operandSegments;
22412282
addOperand(operands, operandSegments, ifCond);
2242-
addOperand(operands, operandSegments, async);
2283+
addOperands(operands, operandSegments, async);
22432284
addOperand(operands, operandSegments, waitDevnum);
22442285
addOperands(operands, operandSegments, waitOperands);
22452286
addOperands(operands, operandSegments, dataClauseOperands);
@@ -2250,8 +2291,18 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
22502291
auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
22512292
builder, currentLocation, eval, operands, operandSegments);
22522293

2253-
dataOp.setAsyncAttr(addAsyncAttr);
2254-
dataOp.setWaitAttr(addWaitAttr);
2294+
if (!asyncDeviceTypes.empty())
2295+
dataOp.setAsyncDeviceTypeAttr(builder.getArrayAttr(asyncDeviceTypes));
2296+
if (!asyncOnlyDeviceTypes.empty())
2297+
dataOp.setAsyncOnlyAttr(builder.getArrayAttr(asyncOnlyDeviceTypes));
2298+
if (!waitOperandsDeviceTypes.empty())
2299+
dataOp.setWaitOperandsDeviceTypeAttr(
2300+
builder.getArrayAttr(waitOperandsDeviceTypes));
2301+
if (!waitOperandsSegments.empty())
2302+
dataOp.setWaitOperandsSegmentsAttr(
2303+
builder.getDenseI32ArrayAttr(waitOperandsSegments));
2304+
if (!waitOnlyDeviceTypes.empty())
2305+
dataOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes));
22552306

22562307
if (hasDefaultNone)
22572308
dataOp.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);

flang/test/Lower/OpenACC/acc-data.f90

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ subroutine acc_data
153153
!$acc end data
154154

155155
! CHECK: acc.data dataOperands(%{{.*}}) {
156-
! CHECK: } attributes {asyncAttr}
156+
! CHECK: } attributes {asyncOnly = [#acc.device_type<none>]}
157157

158158
!$acc data present(a) async(1)
159159
!$acc end data
@@ -165,18 +165,18 @@ subroutine acc_data
165165
!$acc end data
166166

167167
! CHECK: acc.data dataOperands(%{{.*}}) {
168-
! CHECK: } attributes {waitAttr}
168+
! CHECK: } attributes {waitOnly = [#acc.device_type<none>]}
169169

170170
!$acc data present(a) wait(1)
171171
!$acc end data
172172

173-
! CHECK: acc.data dataOperands(%{{.*}}) wait(%{{.*}} : i32) {
173+
! CHECK: acc.data dataOperands(%{{.*}}) wait({%{{.*}} : i32}) {
174174
! CHECK: }{{$}}
175175

176176
!$acc data present(a) wait(devnum: 0: 1)
177177
!$acc end data
178178

179-
! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait(%{{.*}} : i32) {
179+
! CHECK: acc.data dataOperands(%{{.*}}) wait_devnum(%{{.*}} : i32) wait({%{{.*}} : i32}) {
180180
! CHECK: }{{$}}
181181

182182
!$acc data default(none)

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,13 +1236,16 @@ def OpenACC_DataOp : OpenACC_Op<"data",
12361236

12371237

12381238
let arguments = (ins Optional<I1>:$ifCond,
1239-
Optional<IntOrIndex>:$async,
1240-
UnitAttr:$asyncAttr,
1241-
Optional<IntOrIndex>:$waitDevnum,
1242-
Variadic<IntOrIndex>:$waitOperands,
1243-
UnitAttr:$waitAttr,
1244-
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
1245-
OptionalAttr<DefaultValueAttr>:$defaultAttr);
1239+
Variadic<IntOrIndex>:$async,
1240+
OptionalAttr<DeviceTypeArrayAttr>:$asyncDeviceType,
1241+
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
1242+
Optional<IntOrIndex>:$waitDevnum,
1243+
Variadic<IntOrIndex>:$waitOperands,
1244+
OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
1245+
OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
1246+
OptionalAttr<DeviceTypeArrayAttr>:$waitOnly,
1247+
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
1248+
OptionalAttr<DefaultValueAttr>:$defaultAttr);
12461249

12471250
let regions = (region AnyRegion:$region);
12481251

@@ -1252,15 +1255,41 @@ def OpenACC_DataOp : OpenACC_Op<"data",
12521255

12531256
/// The i-th data operand passed.
12541257
Value getDataOperand(unsigned i);
1258+
1259+
/// Return true if the op has the async attribute for the
1260+
/// mlir::acc::DeviceType::None device_type.
1261+
bool hasAsyncOnly();
1262+
/// Return true if the op has the async attribute for the given device_type.
1263+
bool hasAsyncOnly(mlir::acc::DeviceType deviceType);
1264+
/// Return the value of the async clause if present.
1265+
mlir::Value getAsyncValue();
1266+
/// Return the value of the async clause for the given device_type if
1267+
/// present.
1268+
mlir::Value getAsyncValue(mlir::acc::DeviceType deviceType);
1269+
1270+
/// Return true if the op has the wait attribute for the
1271+
/// mlir::acc::DeviceType::None device_type.
1272+
bool hasWaitOnly();
1273+
/// Return true if the op has the wait attribute for the given device_type.
1274+
bool hasWaitOnly(mlir::acc::DeviceType deviceType);
1275+
/// Return the values of the wait clause if present.
1276+
mlir::Operation::operand_range getWaitValues();
1277+
/// Return the values of the wait clause for the given device_type if
1278+
/// present.
1279+
mlir::Operation::operand_range
1280+
getWaitValues(mlir::acc::DeviceType deviceType);
12551281
}];
12561282

12571283
let assemblyFormat = [{
12581284
oilist(
12591285
`if` `(` $ifCond `)`
1260-
| `async` `(` $async `:` type($async) `)`
1286+
| `async` `(` custom<DeviceTypeOperands>($async,
1287+
type($async), $asyncDeviceType) `)`
12611288
| `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
12621289
| `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
1263-
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
1290+
| `wait` `(` custom<DeviceTypeOperandsWithSegment>($waitOperands,
1291+
type($waitOperands), $waitOperandsDeviceType,
1292+
$waitOperandsSegments) `)`
12641293
)
12651294
$region attr-dict-with-keyword
12661295
}];

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1417,11 +1417,52 @@ unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
14171417

14181418
Value DataOp::getDataOperand(unsigned i) {
14191419
unsigned numOptional = getIfCond() ? 1 : 0;
1420-
numOptional += getAsync() ? 1 : 0;
1420+
numOptional += getAsync().size() ? 1 : 0;
14211421
numOptional += getWaitOperands().size();
14221422
return getOperand(numOptional + i);
14231423
}
14241424

1425+
bool acc::DataOp::hasAsyncOnly() {
1426+
return hasAsyncOnly(mlir::acc::DeviceType::None);
1427+
}
1428+
1429+
bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1430+
if (auto arrayAttr = getAsyncOnly()) {
1431+
if (findSegment(*arrayAttr, deviceType))
1432+
return true;
1433+
}
1434+
return false;
1435+
}
1436+
1437+
mlir::Value DataOp::getAsyncValue() {
1438+
return getAsyncValue(mlir::acc::DeviceType::None);
1439+
}
1440+
1441+
mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1442+
return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(),
1443+
deviceType);
1444+
}
1445+
1446+
bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
1447+
1448+
bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1449+
if (auto arrayAttr = getWaitOnly()) {
1450+
if (findSegment(*arrayAttr, deviceType))
1451+
return true;
1452+
}
1453+
return false;
1454+
}
1455+
1456+
mlir::Operation::operand_range DataOp::getWaitValues() {
1457+
return getWaitValues(mlir::acc::DeviceType::None);
1458+
}
1459+
1460+
mlir::Operation::operand_range
1461+
DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1462+
return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(),
1463+
getWaitOperandsSegments(), deviceType);
1464+
}
1465+
14251466
//===----------------------------------------------------------------------===//
14261467
// ExitDataOp
14271468
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenACC/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -836,11 +836,11 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
836836
} attributes { defaultAttr = #acc<defaultvalue none>, wait }
837837

838838
%w1 = arith.constant 1 : i64
839-
acc.data wait(%w1 : i64) {
839+
acc.data wait({%w1 : i64}) {
840840
} attributes { defaultAttr = #acc<defaultvalue none>, wait }
841841

842842
%wd1 = arith.constant 1 : i64
843-
acc.data wait_devnum(%wd1 : i64) wait(%w1 : i64) {
843+
acc.data wait_devnum(%wd1 : i64) wait({%w1 : i64}) {
844844
} attributes { defaultAttr = #acc<defaultvalue none>, wait }
845845

846846
return
@@ -951,10 +951,10 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
951951
// CHECK: acc.data {
952952
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
953953

954-
// CHECK: acc.data wait(%{{.*}} : i64) {
954+
// CHECK: acc.data wait({%{{.*}} : i64}) {
955955
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
956956

957-
// CHECK: acc.data wait_devnum(%{{.*}} : i64) wait(%{{.*}} : i64) {
957+
// CHECK: acc.data wait_devnum(%{{.*}} : i64) wait({%{{.*}} : i64}) {
958958
// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
959959

960960
// -----

0 commit comments

Comments
 (0)