Skip to content

[MLIR][OpenMP][Offload] Lower target update op to DeviceRT #75159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2023

Conversation

ergawy
Copy link
Member

@ergawy ergawy commented Dec 12, 2023

Adds support for lowring UpdateDataOp to the DeviceRT. This reuses the
existing utils used by other device directive.

@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2023

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Kareem Ergawy (ergawy)

Changes

Adds support for lowring UpdateDataOp to the DeviceRT. This reuses the
existing utils used by other device directive.

This is a follow-up to #75047 which is yet to be merged, only the last commit
is part of this PR.


Full diff: https://github.com/llvm/llvm-project/pull/75159.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+50)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+43)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+21-3)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+73)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+12)
  • (modified) mlir/test/Target/LLVMIR/omptarget-llvm.mlir (+19)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8ff5380f71ad45..b9989b335a2aef 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1370,6 +1370,56 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
   let hasVerifier = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// 2.14.6 target update data Construct
+//===---------------------------------------------------------------------===//
+
+def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
+                                                 [AttrSizedOperandSegments]>{
+  let  summary = "target update data construct";
+  let description = [{
+    The target update directive makes the corresponding list items in the device
+    data environment consistent with their original list items, according to the
+    specified motion clauses. The target update construct is a stand-alone
+    directive.
+
+    The optional $if_expr parameter specifies a boolean result of a
+    conditional check. If this value is 1 or is not provided then the target
+    region runs on a device, if it is 0 then the target region is executed
+    on the host device.
+
+    The optional $device parameter specifies the device number for the
+    target region.
+
+    The optional $nowait eliminates the implicit barrier so the parent
+    task can make progress even if the target task is not yet completed.
+
+    We use `MapInfoOp` to model the motion clauses and their modifiers. Even
+    though the spec differentiates between map-types & map-type-modifiers vs.
+    motion-clauses & motion-modifiers, the motion clauses and their modifiers are
+    a subset of map types and their modifiers. The subset relation is handled in
+    during verification to make sure the restrictions for target update are
+    respected.
+
+    TODO: depend clause
+  }];
+
+  let arguments = (ins Optional<I1>:$if_expr,
+                       Optional<AnyInteger>:$device,
+                       UnitAttr:$nowait,
+                       Variadic<OpenMP_PointerLikeType>:$motion_operands);
+
+  let assemblyFormat = [{
+    oilist(`if` `(` $if_expr `:` type($if_expr) `)`
+    | `device` `(` $device `:` type($device) `)`
+    | `nowait` $nowait
+    | `motion_entries` `(` $motion_operands `:` type($motion_operands) `)`
+    ) attr-dict
+   }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.14.5 target construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 20df0099cbd24d..df98cde46877d1 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -691,6 +691,7 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
     if (parser.parseKeyword(&mapTypeMod))
       return failure();
 
+    // Map-type-modifiers
     if (mapTypeMod == "always")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
 
@@ -703,6 +704,7 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
     if (mapTypeMod == "present")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
 
+    // Map-types
     if (mapTypeMod == "to")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
 
@@ -882,6 +884,7 @@ static ParseResult parseCaptureType(OpAsmParser &parser,
 }
 
 static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
+  bool foundRequiredMapTypes = false;
 
   for (auto mapOp : mapOperands) {
     if (!mapOp.getDefiningOp())
@@ -898,6 +901,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
 
       uint64_t mapTypeBits = MapInfoOp.getMapType().value();
 
+      // Map-types
       bool to = mapTypeToBitFlag(
           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
       bool from = mapTypeToBitFlag(
@@ -905,6 +909,14 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
       bool del = mapTypeToBitFlag(
           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
 
+      // Map-type-modifiers
+      bool always = mapTypeToBitFlag(
+          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
+      bool close = mapTypeToBitFlag(
+          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
+      bool implicit = mapTypeToBitFlag(
+          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
+
       if ((isa<DataOp>(op) || isa<TargetOp>(op)) && del)
         return emitError(op->getLoc(),
                          "to, from, tofrom and alloc map types are permitted");
@@ -915,11 +927,38 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
       if (isa<ExitDataOp>(op) && to)
         return emitError(op->getLoc(),
                          "from, release and delete map types are permitted");
+
+      if (isa<UpdateDataOp>(op)) {
+        if (del) {
+          return emitError(op->getLoc(),
+                           "at least one of to or from map types must be "
+                           "specified, other map types are not permitted");
+        }
+
+        if (to | from) {
+          foundRequiredMapTypes = true;
+        }
+      }
+
+      // Check UpdateDataOp's valid map-type-modifiers.
+      if (isa<UpdateDataOp>(op) && (always | close | implicit)) {
+        return emitError(
+            op->getLoc(),
+            "present, mapper and iterator map type modifiers are permitted");
+      }
     } else {
       emitError(op->getLoc(), "map argument is not a map entry operation");
     }
   }
 
+  if (isa<UpdateDataOp>(op)) {
+    if (!foundRequiredMapTypes) {
+      return emitError(op->getLoc(),
+                       "at least one of to or from map types must be "
+                       "specified, other map types are not permitted");
+    }
+  }
+
   return success();
 }
 
@@ -940,6 +979,10 @@ LogicalResult ExitDataOp::verify() {
   return verifyMapClause(*this, getMapOperands());
 }
 
+LogicalResult UpdateDataOp::verify() {
+  return verifyMapClause(*this, getMotionOperands());
+}
+
 LogicalResult TargetOp::verify() {
   return verifyMapClause(*this, getMapOperands());
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4f6200d29a70a6..088e7ae4231bef 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1915,6 +1915,23 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             mapOperands = exitDataOp.getMapOperands();
             return success();
           })
+          .Case([&](omp::UpdateDataOp updateDataOp) {
+            if (updateDataOp.getNowait())
+              return failure();
+
+            if (auto ifExprVar = updateDataOp.getIfExpr())
+              ifCond = moduleTranslation.lookupValue(ifExprVar);
+
+            if (auto devId = updateDataOp.getDevice())
+              if (auto constOp =
+                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
+                  deviceID = intAttr.getInt();
+
+            RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
+            mapOperands = updateDataOp.getMotionOperands();
+            return success();
+          })
           .Default([&](Operation *op) {
             return op->emitError("unsupported OpenMP operation: ")
                    << op->getName();
@@ -2748,9 +2765,10 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
       .Case([&](omp::ThreadprivateOp) {
         return convertOmpThreadprivate(*op, builder, moduleTranslation);
       })
-      .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp>([&](auto op) {
-        return convertOmpTargetData(op, builder, moduleTranslation);
-      })
+      .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp, omp::UpdateDataOp>(
+          [&](auto op) {
+            return convertOmpTargetData(op, builder, moduleTranslation);
+          })
       .Case([&](omp::TargetOp) {
         return convertOmpTarget(*op, builder, moduleTranslation);
       })
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e54808f6cfdee5..aace0241686369 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1658,4 +1658,77 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
   return
 }
 
+// -----
+
+func.func @omp_target_update_data_if(%if_cond : i1) {
+  // expected-error @below {{`if` clause can appear at most once in the expansion of the oilist directive}}
+  omp.target_update_data if(%if_cond : i1) if(%if_cond : i1)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_data_device(%device : si32) {
+  // expected-error @below {{`device` clause can appear at most once in the expansion of the oilist directive}}
+  omp.target_update_data device(%device : si32) device(%device : si32)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_data_nowait() {
+  // expected-error @below {{`nowait` clause can appear at most once in the expansion of the oilist directive}}
+  omp.target_update_data nowait nowait
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_type_2(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(delete) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier_2(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(close, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier_3(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(implicit, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
 llvm.mlir.global internal @_QFsubEx() : i32
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4d88d9ac86fe16..b0a6a5ac0a3fb9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2082,3 +2082,15 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
 
     return
 }
+
+// CHECK-LABEL: omp_target_update_data
+func.func @omp_target_update_data (%if_cond : i1, %device : si32, %map1: memref<?xi32>) -> () {
+    %mapv_from = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> {name = ""}
+
+    %mapv_to = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(present, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+    // CHECK: omp.target_update_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) nowait motion_entries(%{{.*}}, %{{.*}} : memref<?xi32>, memref<?xi32>)
+    omp.target_update_data if(%if_cond : i1) device(%device : si32) nowait motion_entries(%mapv_from , %mapv_to : memref<?xi32>, memref<?xi32>)
+    return
+}
+
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index 9221b410d766ed..acb56e061f4470 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -441,3 +441,22 @@ llvm.func @_QPopenmp_target_use_dev_both() {
 // CHECK:         ret void
 
 // -----
+
+llvm.func @_QPopenmp_target_data_update() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
+  %2 = omp.map_info var_ptr(%1 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+  omp.target_data map_entries(%2 : !llvm.ptr) {
+    %3 = llvm.mlir.constant(99 : i32) : i32
+    llvm.store %3, %1 : i32, !llvm.ptr
+    omp.target_update_data motion_entries(%2 : !llvm.ptr)
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: define void @_QPopenmp_target_data_update
+// CHECK:         call void @__tgt_target_data_begin_mapper
+// CHECK:         call void @__tgt_target_data_update_mapper
+// CHECK:         call void @__tgt_target_data_end_mapper
+// CHECK:         ret void

@ergawy ergawy force-pushed the omp_target_upate_2 branch from f3baeac to c52e956 Compare December 12, 2023 11:03
@ergawy ergawy marked this pull request as draft December 12, 2023 13:27
@ergawy ergawy force-pushed the omp_target_upate_2 branch from c52e956 to 78ec927 Compare December 12, 2023 14:45
@ergawy ergawy marked this pull request as ready for review December 12, 2023 15:08
@ergawy ergawy force-pushed the omp_target_upate_2 branch 2 times, most recently from 0838a3a to 56756c2 Compare December 13, 2023 04:47
@ergawy ergawy force-pushed the omp_target_upate_2 branch 2 times, most recently from bc18e5b to 3b3e078 Compare December 14, 2023 05:50
@@ -441,3 +441,41 @@ llvm.func @_QPopenmp_target_use_dev_both() {
// CHECK: ret void

// -----

llvm.func @_QPopenmp_target_data_update() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add more tests to cover different data types (like Fortran pointers, allocatable arrays etc.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, the FIR dialect is not registered with mlir-translate. So we cannot use fir.ref in omptarget-llvm.mlir.

However, I will double check with @agozillon whether we have existing tests for lowering map_info op's arguments from fir to llvm types or not. If not, I will open another PR with more testing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, LGTM

Copy link
Member

@TIFitis TIFitis left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, Thanks for the patch!
Please make sure build bots are clean before merging 👍🏽

Adds support for lowring `UpdateDataOp` to the DeviceRT. This reuses the
existing utils used by other device directive.
@ergawy ergawy force-pushed the omp_target_upate_2 branch from 3b3e078 to f327f70 Compare December 15, 2023 14:22
@ergawy ergawy merged commit d777504 into llvm:main Dec 18, 2023
ergawy added a commit that referenced this pull request Dec 22, 2023
…75345)

Emits MLIR op corresponding to `!$omp target update` directive. So far,
only motion types: `to` and `from` are supported. Motion modifiers:
`present`, `mapper`, and `iterator` are not supported yet.

This is a follow up to #75047 & #75159, only the last commit is relevant
to this PR.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants