Skip to content

Commit f2089f8

Browse files
committed
[Flang] [OpenMP] [Semantics] [MLIR] [Lowering] Add lowering support for IS_DEVICE_PTR and
HAS_DEVICE_ADDR clauses on OMP TARGET directive.
1 parent decf027 commit f2089f8

File tree

4 files changed

+110
-11
lines changed

4 files changed

+110
-11
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,18 @@ class ClauseProcessor {
604604
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
605605
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
606606
&useDeviceSymbols) const;
607+
bool
608+
processIsDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
609+
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
610+
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
611+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
612+
&isDeviceSymbols) const;
613+
bool
614+
processHasDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
615+
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
616+
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
617+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
618+
&isDeviceSymbols) const;
607619

608620
// Call this method for these clauses that should be supported but are not
609621
// implemented yet. It triggers a compilation error if any of the given
@@ -1890,6 +1902,34 @@ bool ClauseProcessor::processUseDevicePtr(
18901902
});
18911903
}
18921904

1905+
bool ClauseProcessor::processIsDevicePtr(
1906+
llvm::SmallVectorImpl<mlir::Value> &operands,
1907+
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
1908+
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
1909+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
1910+
const {
1911+
return findRepeatableClause<ClauseTy::IsDevicePtr>(
1912+
[&](const ClauseTy::IsDevicePtr *devPtrClause,
1913+
const Fortran::parser::CharBlock &) {
1914+
addUseDeviceClause(converter, devPtrClause->v, operands, isDeviceTypes,
1915+
isDeviceLocs, isDeviceSymbols);
1916+
});
1917+
}
1918+
1919+
bool ClauseProcessor::processHasDeviceAddr(
1920+
llvm::SmallVectorImpl<mlir::Value> &operands,
1921+
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
1922+
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
1923+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
1924+
const {
1925+
return findRepeatableClause<ClauseTy::HasDeviceAddr>(
1926+
[&](const ClauseTy::HasDeviceAddr *devAddrClause,
1927+
const Fortran::parser::CharBlock &) {
1928+
addUseDeviceClause(converter, devAddrClause->v, operands, isDeviceTypes,
1929+
isDeviceLocs, isDeviceSymbols);
1930+
});
1931+
}
1932+
18931933
template <typename... Ts>
18941934
void ClauseProcessor::processTODO(mlir::Location currentLocation,
18951935
llvm::omp::Directive directive) const {
@@ -2617,6 +2657,10 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
26172657
llvm::SmallVector<mlir::Type> mapSymTypes;
26182658
llvm::SmallVector<mlir::Location> mapSymLocs;
26192659
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
2660+
llvm::SmallVector<mlir::Value> devicePtrOperands, deviceAddrOperands;
2661+
llvm::SmallVector<mlir::Type> useDeviceTypes;
2662+
llvm::SmallVector<mlir::Location> useDeviceLocs;
2663+
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
26202664

26212665
ClauseProcessor cp(converter, clauseList);
26222666
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
@@ -2626,11 +2670,13 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
26262670
cp.processNowait(nowaitAttr);
26272671
cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
26282672
mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols);
2673+
cp.processIsDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
2674+
useDeviceSymbols);
2675+
cp.processHasDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
2676+
useDeviceSymbols);
26292677
cp.processTODO<Fortran::parser::OmpClause::Private,
26302678
Fortran::parser::OmpClause::Depend,
26312679
Fortran::parser::OmpClause::Firstprivate,
2632-
Fortran::parser::OmpClause::IsDevicePtr,
2633-
Fortran::parser::OmpClause::HasDeviceAddr,
26342680
Fortran::parser::OmpClause::Reduction,
26352681
Fortran::parser::OmpClause::InReduction,
26362682
Fortran::parser::OmpClause::Allocate,
@@ -2705,7 +2751,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
27052751

27062752
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
27072753
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
2708-
nowaitAttr, mapOperands);
2754+
nowaitAttr, devicePtrOperands, deviceAddrOperands, mapOperands);
27092755

27102756
genBodyOfTargetOp(converter, eval, targetOp, mapSymTypes, mapSymLocs,
27112757
mapSymbols, currentLocation);
@@ -3101,6 +3147,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
31013147
!std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
31023148
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
31033149
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
3150+
!std::get_if<Fortran::parser::OmpClause::IsDevicePtr>(&clause.u) &&
3151+
!std::get_if<Fortran::parser::OmpClause::HasDeviceAddr>(&clause.u) &&
31043152
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
31053153
!std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
31063154
TODO(clauseLocation, "OpenMP Block construct clause");

flang/test/Lower/OpenMP/FIR/target.f90

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,4 +411,43 @@ subroutine omp_target_parallel_do
411411
!CHECK: omp.terminator
412412
!CHECK: }
413413
!$omp end target parallel do
414-
end subroutine omp_target_parallel_do
414+
end subroutine omp_target_parallel_do
415+
416+
!===============================================================================
417+
! Target `is_device_ptr` clause
418+
!===============================================================================
419+
420+
!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
421+
subroutine omp_target_is_device_ptr
422+
use iso_c_binding, only : c_ptr, c_loc
423+
!CHECK: %[[DEV_PTR:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "a", uniq_name = "_QFomp_target_is_device_ptrEa"}
424+
type(c_ptr) :: a
425+
!CHECK %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "b", fir.target, uniq_name = "_QFomp_target_is_device_ptrEb"}
426+
integer, target :: b
427+
!CHECK: %[[MAP_0:.*]] = omp.map_info var_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> {name = "a"}
428+
!CHECK: %[[MAP_1:.*]] = omp.map_info var_ptr(%[[VAL_0:.*]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "b"}
429+
!CHECK: omp.target is_device_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) map_entries(%[[MAP_0:.*]], %[[MAP_1:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<i32>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
430+
!$omp target map(tofrom: a,b) is_device_ptr(a)
431+
!CHECK: {{.*}} = fir.coordinate_of %[[DEV_PTR:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
432+
a = c_loc(b)
433+
!CHECK: omp.terminator
434+
!$omp end target
435+
!CHECK: }
436+
end subroutine omp_target_is_device_ptr
437+
438+
!===============================================================================
439+
! Target `has_device_addr` clause
440+
!===============================================================================
441+
442+
!CHECK-LABEL: func.func @_QPomp_target_has_device_addr() {
443+
subroutine omp_target_has_device_addr
444+
integer, pointer :: a
445+
!CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFomp_target_has_device_addrEa"}
446+
!CHECK: omp.target has_device_addr(%[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
447+
!$omp target has_device_addr(a)
448+
!CHECK: {{.*}} = fir.load %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
449+
a = 10
450+
!CHECK: omp.terminator
451+
!$omp end target
452+
!CHECK: }
453+
end subroutine omp_target_has_device_addr

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,26 +1389,38 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
13891389

13901390
The optional $thread_limit specifies the limit on the number of threads
13911391

1392-
The optional $nowait elliminates the implicit barrier so the parent task can make progress
1392+
The optional $nowait eliminates the implicit barrier so the parent task can make progress
13931393
even if the target task is not yet completed.
13941394

1395-
TODO: is_device_ptr, depend, defaultmap, in_reduction
1395+
The optional $is_device_ptr indicates list items are device pointers
1396+
1397+
The optional $has_device_addr indicates that list items already have device
1398+
addresses, so may be directly accessed from target device. May include array
1399+
sections.
1400+
1401+
The optional $map_operands maps data from the task’s environment to the
1402+
device environment.
1403+
1404+
TODO: depend, defaultmap, in_reduction
13961405

13971406
}];
13981407

13991408
let arguments = (ins Optional<I1>:$if_expr,
14001409
Optional<AnyInteger>:$device,
14011410
Optional<AnyInteger>:$thread_limit,
14021411
UnitAttr:$nowait,
1412+
Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
1413+
Variadic<OpenMP_PointerLikeType>:$has_device_addr,
14031414
Variadic<AnyType>:$map_operands);
1404-
14051415
let regions = (region AnyRegion:$region);
14061416

14071417
let assemblyFormat = [{
14081418
oilist( `if` `(` $if_expr `)`
14091419
| `device` `(` $device `:` type($device) `)`
14101420
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
14111421
| `nowait` $nowait
1422+
| `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
1423+
| `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
14121424
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
14131425
) $region attr-dict
14141426
}];

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,22 +480,22 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
480480
}
481481

482482
// CHECK-LABEL: omp_target
483-
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
483+
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %device_ptr: memref<i32>, %device_addr: memref<?xi32>, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
484484

485485
// Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
486486
// CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
487487
"omp.target"(%if_cond, %device, %num_threads) ({
488488
// CHECK: omp.terminator
489489
omp.terminator
490-
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
490+
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : ( i1, si32, i32 ) -> ()
491491

492492
// Test with optional map clause.
493493
// CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
494494
// CHECK: %[[MAP_B:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
495-
// CHECK: omp.target map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
495+
// CHECK: omp.target is_device_ptr(%[[VAL_4:.*]] : memref<i32>) has_device_addr(%[[VAL_5:.*]] : memref<?xi32>) map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
496496
%mapv1 = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
497497
%mapv2 = omp.map_info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
498-
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) {
498+
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) is_device_ptr(%device_ptr : memref<i32>) has_device_addr(%device_addr : memref<?xi32>) {
499499
^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>):
500500
omp.terminator
501501
}

0 commit comments

Comments
 (0)