Skip to content

Commit d777504

Browse files
authored
[MLIR][OpenMP][Offload] Lower target update op to DeviceRT (#75159)
Adds support for lowring `UpdateDataOp` to the DeviceRT. This reuses the existing utils used by other device directive.
1 parent d01be3c commit d777504

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1915,6 +1915,23 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
19151915
mapOperands = exitDataOp.getMapOperands();
19161916
return success();
19171917
})
1918+
.Case([&](omp::UpdateDataOp updateDataOp) {
1919+
if (updateDataOp.getNowait())
1920+
return failure();
1921+
1922+
if (auto ifExprVar = updateDataOp.getIfExpr())
1923+
ifCond = moduleTranslation.lookupValue(ifExprVar);
1924+
1925+
if (auto devId = updateDataOp.getDevice())
1926+
if (auto constOp =
1927+
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
1928+
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
1929+
deviceID = intAttr.getInt();
1930+
1931+
RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
1932+
mapOperands = updateDataOp.getMotionOperands();
1933+
return success();
1934+
})
19181935
.Default([&](Operation *op) {
19191936
return op->emitError("unsupported OpenMP operation: ")
19201937
<< op->getName();
@@ -2748,9 +2765,10 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
27482765
.Case([&](omp::ThreadprivateOp) {
27492766
return convertOmpThreadprivate(*op, builder, moduleTranslation);
27502767
})
2751-
.Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp>([&](auto op) {
2752-
return convertOmpTargetData(op, builder, moduleTranslation);
2753-
})
2768+
.Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp, omp::UpdateDataOp>(
2769+
[&](auto op) {
2770+
return convertOmpTargetData(op, builder, moduleTranslation);
2771+
})
27542772
.Case([&](omp::TargetOp) {
27552773
return convertOmpTarget(*op, builder, moduleTranslation);
27562774
})

mlir/test/Target/LLVMIR/omptarget-llvm.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,41 @@ llvm.func @_QPopenmp_target_use_dev_both() {
441441
// CHECK: ret void
442442

443443
// -----
444+
445+
llvm.func @_QPopenmp_target_data_update() {
446+
%0 = llvm.mlir.constant(1 : i64) : i64
447+
%1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
448+
%2 = omp.map_info var_ptr(%1 : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
449+
omp.target_data map_entries(%2 : !llvm.ptr) {
450+
%3 = llvm.mlir.constant(99 : i32) : i32
451+
llvm.store %3, %1 : i32, !llvm.ptr
452+
omp.terminator
453+
}
454+
455+
omp.target_update_data motion_entries(%2 : !llvm.ptr)
456+
457+
llvm.return
458+
}
459+
460+
// CHECK-LABEL: define void @_QPopenmp_target_data_update
461+
462+
// CHECK-DAG: %[[OFFLOAD_BASEPTRS:.*]] = alloca [1 x ptr], align 8
463+
// CHECK-DAG: %[[OFFLOAD_PTRS:.*]] = alloca [1 x ptr], align 8
464+
// CHECK-DAG: %[[INT_ALLOCA:.*]] = alloca i32, i64 1, align 4
465+
// CHECK-DAG: %[[OFFLOAD_MAPPERS:.*]] = alloca [1 x ptr], align 8
466+
467+
// CHECK: call void @__tgt_target_data_begin_mapper
468+
// CHECK: store i32 99, ptr %[[INT_ALLOCA]], align 4
469+
// CHECK: call void @__tgt_target_data_end_mapper
470+
471+
// CHECK: %[[BASEPTRS_VAL:.*]] = getelementptr inbounds [1 x ptr], ptr %[[OFFLOAD_BASEPTRS]], i32 0, i32 0
472+
// CHECK: store ptr %[[INT_ALLOCA]], ptr %[[BASEPTRS_VAL]], align 8
473+
// CHECK: %[[PTRS_VAL:.*]] = getelementptr inbounds [1 x ptr], ptr %[[OFFLOAD_PTRS]], i32 0, i32 0
474+
// CHECK: store ptr %[[INT_ALLOCA]], ptr %[[PTRS_VAL]], align 8
475+
// CHECK: %[[MAPPERS_VAL:.*]] = getelementptr inbounds [1 x ptr], ptr %[[OFFLOAD_MAPPERS]], i64 0, i64 0
476+
// CHECK: store ptr null, ptr %[[MAPPERS_VAL]], align 8
477+
// CHECK: %[[BASEPTRS_VAL_2:.*]] = getelementptr inbounds [1 x ptr], ptr %[[OFFLOAD_BASEPTRS]], i32 0, i32 0
478+
// CHECK: %[[PTRS_VAL_2:.*]] = getelementptr inbounds [1 x ptr], ptr %[[OFFLOAD_PTRS]], i32 0, i32 0
479+
// CHECK: call void @__tgt_target_data_update_mapper(ptr @2, i64 -1, i32 1, ptr %[[BASEPTRS_VAL_2]], ptr %[[PTRS_VAL_2]], ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr null)
480+
481+
// CHECK: ret void

0 commit comments

Comments
 (0)