Skip to content

Commit 2ab926d

Browse files
authored
[flang][MLIR][OpenMP] Add support for target update directive. (#75047)
Add an op in the OMP dialect to model the `target update` direcive. This change reuses the `MapInfoOp` used by other device directive to model `map` clauses but verifies that the restrictions imposed by the `target update` directive are respected.
1 parent 101083e commit 2ab926d

File tree

4 files changed

+180
-5
lines changed

4 files changed

+180
-5
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,6 +1370,56 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
13701370
let hasVerifier = 1;
13711371
}
13721372

1373+
//===---------------------------------------------------------------------===//
1374+
// 2.14.6 target update data Construct
1375+
//===---------------------------------------------------------------------===//
1376+
1377+
def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
1378+
[AttrSizedOperandSegments]>{
1379+
let summary = "target update data construct";
1380+
let description = [{
1381+
The target update directive makes the corresponding list items in the device
1382+
data environment consistent with their original list items, according to the
1383+
specified motion clauses. The target update construct is a stand-alone
1384+
directive.
1385+
1386+
The optional $if_expr parameter specifies a boolean result of a
1387+
conditional check. If this value is 1 or is not provided then the target
1388+
region runs on a device, if it is 0 then the target region is executed
1389+
on the host device.
1390+
1391+
The optional $device parameter specifies the device number for the
1392+
target region.
1393+
1394+
The optional $nowait eliminates the implicit barrier so the parent
1395+
task can make progress even if the target task is not yet completed.
1396+
1397+
We use `MapInfoOp` to model the motion clauses and their modifiers. Even
1398+
though the spec differentiates between map-types & map-type-modifiers vs.
1399+
motion-clauses & motion-modifiers, the motion clauses and their modifiers are
1400+
a subset of map types and their modifiers. The subset relation is handled in
1401+
during verification to make sure the restrictions for target update are
1402+
respected.
1403+
1404+
TODO: depend clause
1405+
}];
1406+
1407+
let arguments = (ins Optional<I1>:$if_expr,
1408+
Optional<AnyInteger>:$device,
1409+
UnitAttr:$nowait,
1410+
Variadic<OpenMP_PointerLikeType>:$motion_operands);
1411+
1412+
let assemblyFormat = [{
1413+
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
1414+
| `device` `(` $device `:` type($device) `)`
1415+
| `nowait` $nowait
1416+
| `motion_entries` `(` $motion_operands `:` type($motion_operands) `)`
1417+
) attr-dict
1418+
}];
1419+
1420+
let hasVerifier = 1;
1421+
}
1422+
13731423
//===----------------------------------------------------------------------===//
13741424
// 2.14.5 target construct
13751425
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -882,21 +882,22 @@ static ParseResult parseCaptureType(OpAsmParser &parser,
882882
}
883883

884884
static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
885+
llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateToVars;
886+
llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateFromVars;
885887

886888
for (auto mapOp : mapOperands) {
887889
if (!mapOp.getDefiningOp())
888890
emitError(op->getLoc(), "missing map operation");
889891

890-
if (auto MapInfoOp =
892+
if (auto mapInfoOp =
891893
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
892-
893-
if (!MapInfoOp.getMapType().has_value())
894+
if (!mapInfoOp.getMapType().has_value())
894895
emitError(op->getLoc(), "missing map type for map operand");
895896

896-
if (!MapInfoOp.getMapCaptureType().has_value())
897+
if (!mapInfoOp.getMapCaptureType().has_value())
897898
emitError(op->getLoc(), "missing map capture type for map operand");
898899

899-
uint64_t mapTypeBits = MapInfoOp.getMapType().value();
900+
uint64_t mapTypeBits = mapInfoOp.getMapType().value();
900901

901902
bool to = mapTypeToBitFlag(
902903
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
@@ -905,6 +906,13 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
905906
bool del = mapTypeToBitFlag(
906907
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
907908

909+
bool always = mapTypeToBitFlag(
910+
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
911+
bool close = mapTypeToBitFlag(
912+
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
913+
bool implicit = mapTypeToBitFlag(
914+
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
915+
908916
if ((isa<DataOp>(op) || isa<TargetOp>(op)) && del)
909917
return emitError(op->getLoc(),
910918
"to, from, tofrom and alloc map types are permitted");
@@ -915,6 +923,37 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
915923
if (isa<ExitDataOp>(op) && to)
916924
return emitError(op->getLoc(),
917925
"from, release and delete map types are permitted");
926+
927+
if (isa<UpdateDataOp>(op)) {
928+
if (del) {
929+
return emitError(op->getLoc(),
930+
"at least one of to or from map types must be "
931+
"specified, other map types are not permitted");
932+
}
933+
934+
if (!to && !from) {
935+
return emitError(op->getLoc(),
936+
"at least one of to or from map types must be "
937+
"specified, other map types are not permitted");
938+
}
939+
940+
auto updateVar = mapInfoOp.getVarPtr();
941+
942+
if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
943+
(from && updateToVars.contains(updateVar))) {
944+
return emitError(
945+
op->getLoc(),
946+
"either to or from map types can be specified, not both");
947+
}
948+
949+
if (always || close || implicit) {
950+
return emitError(
951+
op->getLoc(),
952+
"present, mapper and iterator map type modifiers are permitted");
953+
}
954+
955+
to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
956+
}
918957
} else {
919958
emitError(op->getLoc(), "map argument is not a map entry operation");
920959
}
@@ -940,6 +979,10 @@ LogicalResult ExitDataOp::verify() {
940979
return verifyMapClause(*this, getMapOperands());
941980
}
942981

982+
LogicalResult UpdateDataOp::verify() {
983+
return verifyMapClause(*this, getMotionOperands());
984+
}
985+
943986
LogicalResult TargetOp::verify() {
944987
return verifyMapClause(*this, getMapOperands());
945988
}

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1658,4 +1658,74 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
16581658
return
16591659
}
16601660

1661+
// -----
1662+
1663+
func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
1664+
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
1665+
1666+
// expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
1667+
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
1668+
return
1669+
}
1670+
1671+
// -----
1672+
1673+
func.func @omp_target_update_invalid_motion_type_2(%map1 : memref<?xi32>) {
1674+
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(delete) capture(ByRef) -> memref<?xi32> {name = ""}
1675+
1676+
// expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
1677+
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
1678+
return
1679+
}
1680+
1681+
// -----
1682+
1683+
func.func @omp_target_update_invalid_motion_modifier(%map1 : memref<?xi32>) {
1684+
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
1685+
1686+
// expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
1687+
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
1688+
return
1689+
}
1690+
1691+
// -----
1692+
1693+
func.func @omp_target_update_invalid_motion_modifier_2(%map1 : memref<?xi32>) {
1694+
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(close, to) capture(ByRef) -> memref<?xi32> {name = ""}
1695+
1696+
// expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
1697+
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
1698+
return
1699+
}
1700+
1701+
// -----
1702+
1703+
func.func @omp_target_update_invalid_motion_modifier_3(%map1 : memref<?xi32>) {
1704+
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(implicit, to) capture(ByRef) -> memref<?xi32> {name = ""}
1705+
1706+
// expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
1707+
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
1708+
return
1709+
}
1710+
1711+
// -----
1712+
1713+
func.func @omp_target_update_invalid_motion_modifier_4(%map1 : memref<?xi32>) {
1714+
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(implicit, tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
1715+
1716+
// expected-error @below {{either to or from map types can be specified, not both}}
1717+
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
1718+
return
1719+
}
1720+
1721+
// -----
1722+
1723+
func.func @omp_target_update_invalid_motion_modifier_5(%map1 : memref<?xi32>) {
1724+
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
1725+
%mapv2 = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> {name = ""}
1726+
1727+
// expected-error @below {{either to or from map types can be specified, not both}}
1728+
omp.target_update_data motion_entries(%mapv, %mapv2 : memref<?xi32>, memref<?xi32>)
1729+
return
1730+
}
16611731
llvm.mlir.global internal @_QFsubEx() : i32

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2082,3 +2082,15 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
20822082

20832083
return
20842084
}
2085+
2086+
// CHECK-LABEL: omp_target_update_data
2087+
func.func @omp_target_update_data (%if_cond : i1, %device : si32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
2088+
%mapv_from = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> {name = ""}
2089+
2090+
%mapv_to = omp.map_info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(present, to) capture(ByRef) -> memref<?xi32> {name = ""}
2091+
2092+
// CHECK: omp.target_update_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) nowait motion_entries(%{{.*}}, %{{.*}} : memref<?xi32>, memref<?xi32>)
2093+
omp.target_update_data if(%if_cond : i1) device(%device : si32) nowait motion_entries(%mapv_from , %mapv_to : memref<?xi32>, memref<?xi32>)
2094+
return
2095+
}
2096+

0 commit comments

Comments
 (0)