-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][OpenMP][Offload] Lower target update op to DeviceRT #75159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-openmp @llvm/pr-subscribers-mlir Author: Kareem Ergawy (ergawy) ChangesAdds support for lowring This is a follow-up to #75047 which is yet to be merged, only the last commit Full diff: https://github.com/llvm/llvm-project/pull/75159.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8ff5380f71ad45..b9989b335a2aef 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1370,6 +1370,56 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
let hasVerifier = 1;
}
+//===---------------------------------------------------------------------===//
+// 2.14.6 target update data Construct
+//===---------------------------------------------------------------------===//
+
+def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
+ [AttrSizedOperandSegments]>{
+ let summary = "target update data construct";
+ let description = [{
+ The target update directive makes the corresponding list items in the device
+ data environment consistent with their original list items, according to the
+ specified motion clauses. The target update construct is a stand-alone
+ directive.
+
+ The optional $if_expr parameter specifies a boolean result of a
+ conditional check. If this value is 1 or is not provided then the target
+ region runs on a device, if it is 0 then the target region is executed
+ on the host device.
+
+ The optional $device parameter specifies the device number for the
+ target region.
+
+ The optional $nowait eliminates the implicit barrier so the parent
+ task can make progress even if the target task is not yet completed.
+
+ We use `MapInfoOp` to model the motion clauses and their modifiers. Even
+ though the spec differentiates between map-types & map-type-modifiers vs.
+ motion-clauses & motion-modifiers, the motion clauses and their modifiers are
+ a subset of map types and their modifiers. The subset relation is handled in
+ during verification to make sure the restrictions for target update are
+ respected.
+
+ TODO: depend clause
+ }];
+
+ let arguments = (ins Optional<I1>:$if_expr,
+ Optional<AnyInteger>:$device,
+ UnitAttr:$nowait,
+ Variadic<OpenMP_PointerLikeType>:$motion_operands);
+
+ let assemblyFormat = [{
+ oilist(`if` `(` $if_expr `:` type($if_expr) `)`
+ | `device` `(` $device `:` type($device) `)`
+ | `nowait` $nowait
+ | `motion_entries` `(` $motion_operands `:` type($motion_operands) `)`
+ ) attr-dict
+ }];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// 2.14.5 target construct
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 20df0099cbd24d..df98cde46877d1 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -691,6 +691,7 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
if (parser.parseKeyword(&mapTypeMod))
return failure();
+ // Map-type-modifiers
if (mapTypeMod == "always")
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
@@ -703,6 +704,7 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
if (mapTypeMod == "present")
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
+ // Map-types
if (mapTypeMod == "to")
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
@@ -882,6 +884,7 @@ static ParseResult parseCaptureType(OpAsmParser &parser,
}
static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
+ bool foundRequiredMapTypes = false;
for (auto mapOp : mapOperands) {
if (!mapOp.getDefiningOp())
@@ -898,6 +901,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
uint64_t mapTypeBits = MapInfoOp.getMapType().value();
+ // Map-types
bool to = mapTypeToBitFlag(
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
bool from = mapTypeToBitFlag(
@@ -905,6 +909,14 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
bool del = mapTypeToBitFlag(
mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
+ // Map-type-modifiers
+ bool always = mapTypeToBitFlag(
+ mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
+ bool close = mapTypeToBitFlag(
+ mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
+ bool implicit = mapTypeToBitFlag(
+ mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
+
if ((isa<DataOp>(op) || isa<TargetOp>(op)) && del)
return emitError(op->getLoc(),
"to, from, tofrom and alloc map types are permitted");
@@ -915,11 +927,38 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
if (isa<ExitDataOp>(op) && to)
return emitError(op->getLoc(),
"from, release and delete map types are permitted");
+
+ if (isa<UpdateDataOp>(op)) {
+ if (del) {
+ return emitError(op->getLoc(),
+ "at least one of to or from map types must be "
+ "specified, other map types are not permitted");
+ }
+
+ if (to | from) {
+ foundRequiredMapTypes = true;
+ }
+ }
+
+ // Check UpdateDataOp's valid map-type-modifiers.
+ if (isa<UpdateDataOp>(op) && (always | close | implicit)) {
+ return emitError(
+ op->getLoc(),
+ "present, mapper and iterator map type modifiers are permitted");
+ }
} else {
emitError(op->getLoc(), "map argument is not a map entry operation");
}
}
+ if (isa<UpdateDataOp>(op)) {
+ if (!foundRequiredMapTypes) {
+ return emitError(op->getLoc(),
+ "at least one of to or from map types must be "
+ "specified, other map types are not permitted");
+ }
+ }
+
return success();
}
@@ -940,6 +979,10 @@ LogicalResult ExitDataOp::verify() {
return verifyMapClause(*this, getMapOperands());
}
+LogicalResult UpdateDataOp::verify() {
+ return verifyMapClause(*this, getMotionOperands());
+}
+
LogicalResult TargetOp::verify() {
return verifyMapClause(*this, getMapOperands());
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4f6200d29a70a6..088e7ae4231bef 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1915,6 +1915,23 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
mapOperands = exitDataOp.getMapOperands();
return success();
})
+ .Case([&](omp::UpdateDataOp updateDataOp) {
+ if (updateDataOp.getNowait())
+ return failure();
+
+ if (auto ifExprVar = updateDataOp.getIfExpr())
+ ifCond = moduleTranslation.lookupValue(ifExprVar);
+
+ if (auto devId = updateDataOp.getDevice())
+ if (auto constOp =
+ dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
+ deviceID = intAttr.getInt();
+
+ RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
+ mapOperands = updateDataOp.getMotionOperands();
+ return success();
+ })
.Default([&](Operation *op) {
return op->emitError("unsupported OpenMP operation: ")
<< op->getName();
@@ -2748,9 +2765,10 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
.Case([&](omp::ThreadprivateOp) {
return convertOmpThreadprivate(*op, builder, moduleTranslation);
})
- .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp>([&](auto op) {
- return convertOmpTargetData(op, builder, moduleTranslation);
- })
+ .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp, omp::UpdateDataOp>(
+ [&](auto op) {
+ return convertOmpTargetData(op, builder, moduleTranslation);
+ })
.Case([&](omp::TargetOp) {
return convertOmpTarget(*op, builder, moduleTranslation);
})
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e54808f6cfdee5..aace0241686369 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1658,4 +1658,77 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
return
}
+// -----
+
+func.func @omp_target_update_data_if(%if_cond : i1) {
+ // expected-error @below {{`if` clause can appear at most once in the expansion of the oilist directive}}
+ omp.target_update_data if(%if_cond : i1) if(%if_cond : i1)
+ return
+}
+
+// -----
+
+func.func @omp_target_update_data_device(%device : si32) {
+ // expected-error @below {{`device` clause can appear at most once in the expansion of the oilist directive}}
+ omp.target_update_data device(%device : si32) device(%device : si32)
+ return
+}
+
+// -----
+
+func.func @omp_target_update_data_nowait() {
+ // expected-error @below {{`nowait` clause can appear at most once in the expansion of the oilist directive}}
+ omp.target_update_data nowait nowait
+ return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
+ %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+ // expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
+ omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+ return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_type_2(%map1 : memref<?xi32>) {
+ %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(delete) capture(ByRef) -> memref<?xi32> {name = ""}
+
+ // expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
+ omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+ return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier(%map1 : memref<?xi32>) {
+ %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+ // expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
+ omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+ return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier_2(%map1 : memref<?xi32>) {
+ %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(close, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+ // expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
+ omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+ return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier_3(%map1 : memref<?xi32>) {
+ %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(implicit, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+ // expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
+ omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+ return
+}
llvm.mlir.global internal @_QFsubEx() : i32
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4d88d9ac86fe16..b0a6a5ac0a3fb9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2082,3 +2082,15 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
return
}
+
+// CHECK-LABEL: omp_target_update_data
+func.func @omp_target_update_data (%if_cond : i1, %device : si32, %map1: memref<?xi32>) -> () {
+ %mapv_from = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> {name = ""}
+
+ %mapv_to = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(present, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+ // CHECK: omp.target_update_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) nowait motion_entries(%{{.*}}, %{{.*}} : memref<?xi32>, memref<?xi32>)
+ omp.target_update_data if(%if_cond : i1) device(%device : si32) nowait motion_entries(%mapv_from , %mapv_to : memref<?xi32>, memref<?xi32>)
+ return
+}
+
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index 9221b410d766ed..acb56e061f4470 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -441,3 +441,22 @@ llvm.func @_QPopenmp_target_use_dev_both() {
// CHECK: ret void
// -----
+
+llvm.func @_QPopenmp_target_data_update() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
+ %2 = omp.map_info var_ptr(%1 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+ omp.target_data map_entries(%2 : !llvm.ptr) {
+ %3 = llvm.mlir.constant(99 : i32) : i32
+ llvm.store %3, %1 : i32, !llvm.ptr
+ omp.target_update_data motion_entries(%2 : !llvm.ptr)
+ omp.terminator
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: define void @_QPopenmp_target_data_update
+// CHECK: call void @__tgt_target_data_begin_mapper
+// CHECK: call void @__tgt_target_data_update_mapper
+// CHECK: call void @__tgt_target_data_end_mapper
+// CHECK: ret void
|
f3baeac
to
c52e956
Compare
c52e956
to
78ec927
Compare
0838a3a
to
56756c2
Compare
bc18e5b
to
3b3e078
Compare
@@ -441,3 +441,41 @@ llvm.func @_QPopenmp_target_use_dev_both() { | |||
// CHECK: ret void | |||
|
|||
// ----- | |||
|
|||
llvm.func @_QPopenmp_target_data_update() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add more tests to cover different data types (like Fortran pointers, allocatable arrays etc.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed offline, the FIR dialect is not registered with mlir-translate
. So we cannot use fir.ref
in omptarget-llvm.mlir
.
However, I will double check with @agozillon whether we have existing tests for lowering map_info
op's arguments from fir
to llvm
types or not. If not, I will open another PR with more testing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, Thanks for the patch!
Please make sure build bots are clean before merging 👍🏽
Adds support for lowring `UpdateDataOp` to the DeviceRT. This reuses the existing utils used by other device directive.
3b3e078
to
f327f70
Compare
Adds support for lowring
UpdateDataOp
to the DeviceRT. This reuses theexisting utils used by other device directive.