Skip to content

Commit d879ead

Browse files
committed
[Flang] [OpenMP] [Semantics] [MLIR] [Lowering] Add lowering support for IS_DEVICE_PTR and
HAS_DEVICE_ADDR clauses on OMP TARGET directive.
1 parent aaf9164 commit d879ead

File tree

5 files changed

+112
-12
lines changed

5 files changed

+112
-12
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,18 @@ class ClauseProcessor {
641641
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
642642
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
643643
&useDeviceSymbols) const;
644+
bool
645+
processIsDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
646+
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
647+
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
648+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
649+
&isDeviceSymbols) const;
650+
bool
651+
processHasDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
652+
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
653+
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
654+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
655+
&isDeviceSymbols) const;
644656

645657
template <typename T>
646658
bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
@@ -2072,6 +2084,34 @@ bool ClauseProcessor::processMotionClauses(
20722084

20732085
mapOperands.push_back(mapOp);
20742086
}
2087+
});
2088+
}
2089+
2090+
bool ClauseProcessor::processIsDevicePtr(
2091+
llvm::SmallVectorImpl<mlir::Value> &operands,
2092+
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
2093+
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
2094+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
2095+
const {
2096+
return findRepeatableClause<ClauseTy::IsDevicePtr>(
2097+
[&](const ClauseTy::IsDevicePtr *devPtrClause,
2098+
const Fortran::parser::CharBlock &) {
2099+
addUseDeviceClause(converter, devPtrClause->v, operands, isDeviceTypes,
2100+
isDeviceLocs, isDeviceSymbols);
2101+
});
2102+
}
2103+
2104+
bool ClauseProcessor::processHasDeviceAddr(
2105+
llvm::SmallVectorImpl<mlir::Value> &operands,
2106+
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
2107+
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
2108+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
2109+
const {
2110+
return findRepeatableClause<ClauseTy::HasDeviceAddr>(
2111+
[&](const ClauseTy::HasDeviceAddr *devAddrClause,
2112+
const Fortran::parser::CharBlock &) {
2113+
addUseDeviceClause(converter, devAddrClause->v, operands, isDeviceTypes,
2114+
isDeviceLocs, isDeviceSymbols);
20752115
});
20762116
}
20772117

@@ -2999,6 +3039,10 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
29993039
llvm::SmallVector<mlir::Type> mapSymTypes;
30003040
llvm::SmallVector<mlir::Location> mapSymLocs;
30013041
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
3042+
llvm::SmallVector<mlir::Value> devicePtrOperands, deviceAddrOperands;
3043+
llvm::SmallVector<mlir::Type> useDeviceTypes;
3044+
llvm::SmallVector<mlir::Location> useDeviceLocs;
3045+
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
30023046

30033047
ClauseProcessor cp(converter, semaCtx, clauseList);
30043048
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
@@ -3008,11 +3052,13 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
30083052
cp.processNowait(nowaitAttr);
30093053
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
30103054
&mapSymLocs, &mapSymbols);
3055+
cp.processIsDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
3056+
useDeviceSymbols);
3057+
cp.processHasDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
3058+
useDeviceSymbols);
30113059
cp.processTODO<Fortran::parser::OmpClause::Private,
30123060
Fortran::parser::OmpClause::Depend,
30133061
Fortran::parser::OmpClause::Firstprivate,
3014-
Fortran::parser::OmpClause::IsDevicePtr,
3015-
Fortran::parser::OmpClause::HasDeviceAddr,
30163062
Fortran::parser::OmpClause::Reduction,
30173063
Fortran::parser::OmpClause::InReduction,
30183064
Fortran::parser::OmpClause::Allocate,
@@ -3093,7 +3139,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
30933139

30943140
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
30953141
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
3096-
nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
3142+
nullptr, mlir::ValueRange(), nowaitAttr, devicePtrOperands,
3143+
deviceAddrOperands, mapOperands);
30973144

30983145
genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
30993146
mapSymLocs, mapSymbols, currentLocation);
@@ -3700,6 +3747,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
37003747
!std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
37013748
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
37023749
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
3750+
!std::get_if<Fortran::parser::OmpClause::IsDevicePtr>(&clause.u) &&
3751+
!std::get_if<Fortran::parser::OmpClause::HasDeviceAddr>(&clause.u) &&
37033752
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
37043753
!std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
37053754
TODO(clauseLocation, "OpenMP Block construct clause");

flang/test/Lower/OpenMP/FIR/target.f90

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,4 +506,43 @@ subroutine omp_target_parallel_do
506506
!CHECK: omp.terminator
507507
!CHECK: }
508508
!$omp end target parallel do
509-
end subroutine omp_target_parallel_do
509+
end subroutine omp_target_parallel_do
510+
511+
!===============================================================================
512+
! Target `is_device_ptr` clause
513+
!===============================================================================
514+
515+
!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
516+
subroutine omp_target_is_device_ptr
517+
use iso_c_binding, only : c_ptr, c_loc
518+
!CHECK: %[[DEV_PTR:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "a", uniq_name = "_QFomp_target_is_device_ptrEa"}
519+
type(c_ptr) :: a
520+
!CHECK: %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "b", fir.target, uniq_name = "_QFomp_target_is_device_ptrEb"}
521+
integer, target :: b
522+
!CHECK: %[[MAP_0:.*]] = omp.map_info var_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> {name = "a"}
523+
!CHECK: %[[MAP_1:.*]] = omp.map_info var_ptr(%[[VAL_0:.*]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "b"}
524+
!CHECK: omp.target is_device_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) map_entries(%[[MAP_0:.*]], %[[MAP_1:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<i32>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
525+
!$omp target map(tofrom: a,b) is_device_ptr(a)
526+
!CHECK: {{.*}} = fir.coordinate_of %[[DEV_PTR:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
527+
a = c_loc(b)
528+
!CHECK: omp.terminator
529+
!$omp end target
530+
!CHECK: }
531+
end subroutine omp_target_is_device_ptr
532+
533+
!===============================================================================
534+
! Target `has_device_addr` clause
535+
!===============================================================================
536+
537+
!CHECK-LABEL: func.func @_QPomp_target_has_device_addr() {
538+
subroutine omp_target_has_device_addr
539+
integer, pointer :: a
540+
!CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFomp_target_has_device_addrEa"}
541+
!CHECK: omp.target has_device_addr(%[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) map_entries({{.*}} -> {{.*}}, {{.*}} -> {{.*}} : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
542+
!$omp target has_device_addr(a)
543+
!CHECK: {{.*}} = fir.load %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
544+
a = 10
545+
!CHECK: omp.terminator
546+
!$omp end target
547+
!CHECK: }
548+
end subroutine omp_target_has_device_addr

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,14 +1609,23 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
16091609

16101610
The optional $thread_limit specifies the limit on the number of threads
16111611

1612-
The optional $nowait elliminates the implicit barrier so the parent task can make progress
1612+
The optional $nowait eliminates the implicit barrier so the parent task can make progress
16131613
even if the target task is not yet completed.
16141614

16151615
The `depends` and `depend_vars` arguments are variadic lists of values
16161616
that specify the dependencies of this particular target task in relation to
16171617
other tasks.
16181618

1619-
TODO: is_device_ptr, defaultmap, in_reduction
1619+
The optional $is_device_ptr indicates that list items are device pointers.
1620+
1621+
The optional $has_device_addr indicates that list items already have device
1622+
addresses, so they may be directly accessed from the target device. List items may include array
1623+
sections.
1624+
1625+
The optional $map_operands maps data from the task’s environment to the
1626+
device environment.
1627+
1628+
TODO: defaultmap, in_reduction
16201629

16211630
}];
16221631

@@ -1626,15 +1635,18 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
16261635
OptionalAttr<TaskDependArrayAttr>:$depends,
16271636
Variadic<OpenMP_PointerLikeType>:$depend_vars,
16281637
UnitAttr:$nowait,
1638+
Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
1639+
Variadic<OpenMP_PointerLikeType>:$has_device_addr,
16291640
Variadic<AnyType>:$map_operands);
1630-
16311641
let regions = (region AnyRegion:$region);
16321642

16331643
let assemblyFormat = [{
16341644
oilist( `if` `(` $if_expr `)`
16351645
| `device` `(` $device `:` type($device) `)`
16361646
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
16371647
| `nowait` $nowait
1648+
| `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
1649+
| `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
16381650
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
16391651
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
16401652
) $region attr-dict

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1790,7 +1790,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
17901790
// expected-error @below {{op expected as many depend values as depend variables}}
17911791
"omp.target"(%data_var) ({
17921792
"omp.terminator"() : () -> ()
1793-
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0>} : (memref<i32>) -> ()
1793+
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
17941794
"func.return"() : () -> ()
17951795
}
17961796

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -510,22 +510,22 @@ return
510510

511511

512512
// CHECK-LABEL: omp_target
513-
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
513+
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %device_ptr: memref<i32>, %device_addr: memref<?xi32>, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
514514

515515
// Test with optional operands: if_expr, device, thread_limit and nowait.
516516
// CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
517517
"omp.target"(%if_cond, %device, %num_threads) ({
518518
// CHECK: omp.terminator
519519
omp.terminator
520-
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0>} : ( i1, si32, i32 ) -> ()
520+
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()
521521

522522
// Test with optional map, is_device_ptr and has_device_addr clauses.
523523
// CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
524524
// CHECK: %[[MAP_B:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
525-
// CHECK: omp.target map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
525+
// CHECK: omp.target is_device_ptr(%[[VAL_4:.*]] : memref<i32>) has_device_addr(%[[VAL_5:.*]] : memref<?xi32>) map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
526526
%mapv1 = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
527527
%mapv2 = omp.map_info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
528-
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) {
528+
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) is_device_ptr(%device_ptr : memref<i32>) has_device_addr(%device_addr : memref<?xi32>) {
529529
^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>):
530530
omp.terminator
531531
}

0 commit comments

Comments
 (0)