Skip to content

[flang][MLIR][OpenMP] Add support for target update directive. #75047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,56 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
let hasVerifier = 1;
}

//===---------------------------------------------------------------------===//
// 2.14.6 target update data Construct
//===---------------------------------------------------------------------===//

def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
[AttrSizedOperandSegments]>{
let summary = "target update data construct";
let description = [{
The target update directive makes the corresponding list items in the device
data environment consistent with their original list items, according to the
specified motion clauses. The target update construct is a stand-alone
directive.

The optional $if_expr parameter specifies a boolean result of a
conditional check. If this value is 1 or is not provided then the target
region runs on a device, if it is 0 then the target region is executed
on the host device.

The optional $device parameter specifies the device number for the
target region.

The optional $nowait eliminates the implicit barrier so the parent
task can make progress even if the target task is not yet completed.

We use `MapInfoOp` to model the motion clauses and their modifiers. Even
though the spec differentiates between map-types & map-type-modifiers vs.
motion-clauses & motion-modifiers, the motion clauses and their modifiers are
a subset of map types and their modifiers. The subset relation is handled in
during verification to make sure the restrictions for target update are
respected.

TODO: depend clause
}];

let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$motion_operands);

let assemblyFormat = [{
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `motion_entries` `(` $motion_operands `:` type($motion_operands) `)`
) attr-dict
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// 2.14.5 target construct
//===----------------------------------------------------------------------===//
Expand Down
53 changes: 48 additions & 5 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -882,21 +882,22 @@ static ParseResult parseCaptureType(OpAsmParser &parser,
}

static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateToVars;
llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateFromVars;

for (auto mapOp : mapOperands) {
if (!mapOp.getDefiningOp())
emitError(op->getLoc(), "missing map operation");

if (auto MapInfoOp =
if (auto mapInfoOp =
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {

if (!MapInfoOp.getMapType().has_value())
if (!mapInfoOp.getMapType().has_value())
emitError(op->getLoc(), "missing map type for map operand");

if (!MapInfoOp.getMapCaptureType().has_value())
if (!mapInfoOp.getMapCaptureType().has_value())
emitError(op->getLoc(), "missing map capture type for map operand");

uint64_t mapTypeBits = MapInfoOp.getMapType().value();
uint64_t mapTypeBits = mapInfoOp.getMapType().value();

bool to = mapTypeToBitFlag(
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
Expand All @@ -905,6 +906,13 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
bool del = mapTypeToBitFlag(
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);

bool always = mapTypeToBitFlag(
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
bool close = mapTypeToBitFlag(
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
bool implicit = mapTypeToBitFlag(
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);

if ((isa<DataOp>(op) || isa<TargetOp>(op)) && del)
return emitError(op->getLoc(),
"to, from, tofrom and alloc map types are permitted");
Expand All @@ -915,6 +923,37 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
if (isa<ExitDataOp>(op) && to)
return emitError(op->getLoc(),
"from, release and delete map types are permitted");

if (isa<UpdateDataOp>(op)) {
if (del) {
return emitError(op->getLoc(),
"at least one of to or from map types must be "
"specified, other map types are not permitted");
}

if (!to && !from) {
return emitError(op->getLoc(),
"at least one of to or from map types must be "
"specified, other map types are not permitted");
}

auto updateVar = mapInfoOp.getVarPtr();

if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
(from && updateToVars.contains(updateVar))) {
return emitError(
op->getLoc(),
"either to or from map types can be specified, not both");
}

if (always || close || implicit) {
return emitError(
op->getLoc(),
"present, mapper and iterator map type modifiers are permitted");
}

to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
}
} else {
emitError(op->getLoc(), "map argument is not a map entry operation");
}
Expand All @@ -940,6 +979,10 @@ LogicalResult ExitDataOp::verify() {
return verifyMapClause(*this, getMapOperands());
}

LogicalResult UpdateDataOp::verify() {
return verifyMapClause(*this, getMotionOperands());
}

LogicalResult TargetOp::verify() {
return verifyMapClause(*this, getMapOperands());
}
Expand Down
70 changes: 70 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1658,4 +1658,74 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
return
}

// -----

func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}

// expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
return
}

// -----

func.func @omp_target_update_invalid_motion_type_2(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(delete) capture(ByRef) -> memref<?xi32> {name = ""}

// expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
return
}

// -----

func.func @omp_target_update_invalid_motion_modifier(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}

// expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
return
}

// -----

func.func @omp_target_update_invalid_motion_modifier_2(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(close, to) capture(ByRef) -> memref<?xi32> {name = ""}

// expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
return
}

// -----

func.func @omp_target_update_invalid_motion_modifier_3(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(implicit, to) capture(ByRef) -> memref<?xi32> {name = ""}

// expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
return
}

// -----

func.func @omp_target_update_invalid_motion_modifier_4(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(implicit, tofrom) capture(ByRef) -> memref<?xi32> {name = ""}

// expected-error @below {{either to or from map types can be specified, not both}}
omp.target_update_data motion_entries(%mapv : memref<?xi32>)
return
}

// -----

func.func @omp_target_update_invalid_motion_modifier_5(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
%mapv2 = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> {name = ""}

// expected-error @below {{either to or from map types can be specified, not both}}
omp.target_update_data motion_entries(%mapv, %mapv2 : memref<?xi32>, memref<?xi32>)
return
}
llvm.mlir.global internal @_QFsubEx() : i32
12 changes: 12 additions & 0 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2082,3 +2082,15 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()

return
}

// CHECK-LABEL: omp_target_update_data
func.func @omp_target_update_data (%if_cond : i1, %device : si32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
%mapv_from = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> {name = ""}

%mapv_to = omp.map_info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(present, to) capture(ByRef) -> memref<?xi32> {name = ""}

// CHECK: omp.target_update_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) nowait motion_entries(%{{.*}}, %{{.*}} : memref<?xi32>, memref<?xi32>)
omp.target_update_data if(%if_cond : i1) device(%device : si32) nowait motion_entries(%mapv_from , %mapv_to : memref<?xi32>, memref<?xi32>)
return
}