Skip to content

[Flang] [OpenMP] [Semantics] [MLIR] [Lowering] Add lowering support for IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on OMP TARGET directive. #74187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,20 @@ bool ClauseProcessor::processDepend(
});
}

bool ClauseProcessor::processHasDeviceAddr(
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
const {
return findRepeatableClause<omp::clause::HasDeviceAddr>(
[&](const omp::clause::HasDeviceAddr &devAddrClause,
const Fortran::parser::CharBlock &) {
addUseDeviceClause(converter, devAddrClause.v, operands, isDeviceTypes,
isDeviceLocs, isDeviceSymbols);
});
}

bool ClauseProcessor::processIf(
omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const {
Expand All @@ -771,6 +785,20 @@ bool ClauseProcessor::processIf(
return found;
}

bool ClauseProcessor::processIsDevicePtr(
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
const {
return findRepeatableClause<omp::clause::IsDevicePtr>(
[&](const omp::clause::IsDevicePtr &devPtrClause,
const Fortran::parser::CharBlock &) {
addUseDeviceClause(converter, devPtrClause.v, operands, isDeviceTypes,
isDeviceLocs, isDeviceSymbols);
});
}

bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::Link>(
Expand Down Expand Up @@ -993,6 +1021,7 @@ bool ClauseProcessor::processUseDevicePtr(
useDeviceLocs, useDeviceSymbols);
});
}

} // namespace omp
} // namespace lower
} // namespace Fortran
12 changes: 12 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ class ClauseProcessor {
bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
bool processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool
processHasDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&isDeviceSymbols) const;
bool processHint(mlir::IntegerAttr &result) const;
bool processMergeable(mlir::UnitAttr &result) const;
bool processNowait(mlir::UnitAttr &result) const;
Expand Down Expand Up @@ -104,6 +110,12 @@ class ClauseProcessor {
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const;
bool
processIsDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&isDeviceSymbols) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;

// This method is used to process a map clause.
Expand Down
22 changes: 17 additions & 5 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,11 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Type> mapSymTypes;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
llvm::SmallVector<mlir::Value> devicePtrOperands, deviceAddrOperands;
llvm::SmallVector<mlir::Type> devicePtrTypes, deviceAddrTypes;
llvm::SmallVector<mlir::Location> devicePtrLocs, deviceAddrLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> devicePtrSymbols,
deviceAddrSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(llvm::omp::Directive::OMPD_target, ifClauseOperand);
Expand All @@ -1303,11 +1308,15 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
&mapSymLocs, &mapSymbols);
cp.processIsDevicePtr(devicePtrOperands, devicePtrTypes, devicePtrLocs,
devicePtrSymbols);
cp.processHasDeviceAddr(deviceAddrOperands, deviceAddrTypes, deviceAddrLocs,
deviceAddrSymbols);

cp.processTODO<clause::Private, clause::Firstprivate, clause::IsDevicePtr,
clause::HasDeviceAddr, clause::Reduction, clause::InReduction,
clause::Allocate, clause::UsesAllocators, clause::Defaultmap>(
currentLocation, llvm::omp::Directive::OMPD_target);
cp.processTODO<clause::Private, clause::Firstprivate, clause::Reduction,
clause::InReduction, clause::Allocate, clause::UsesAllocators,
clause::Defaultmap>(currentLocation,
llvm::omp::Directive::OMPD_target);

// 5.8.1 Implicit Data-Mapping Attribute Rules
// The following code follows the implicit data-mapping rules to map all the
Expand Down Expand Up @@ -1400,7 +1409,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
dependTypeOperands),
dependOperands, nowaitAttr, mapOperands);
dependOperands, nowaitAttr, devicePtrOperands, deviceAddrOperands,
mapOperands);

genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
mapSymLocs, mapSymbols, currentLocation);
Expand Down Expand Up @@ -2059,6 +2069,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
!std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::IsDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::HasDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
Expand Down
43 changes: 42 additions & 1 deletion flang/test/Lower/OpenMP/FIR/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -506,4 +506,45 @@ subroutine omp_target_parallel_do
!CHECK: omp.terminator
!CHECK: }
!$omp end target parallel do
end subroutine omp_target_parallel_do
end subroutine omp_target_parallel_do

!===============================================================================
! Target `is_device_ptr` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
subroutine omp_target_is_device_ptr
use iso_c_binding, only : c_ptr, c_loc
!CHECK: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "a", uniq_name = "_QFomp_target_is_device_ptrEa"}
type(c_ptr) :: a
!CHECK: %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "b", fir.target, uniq_name = "_QFomp_target_is_device_ptrEb"}
integer, target :: b
!CHECK: %[[MAP_0:.*]] = omp.map.info var_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> {name = "a"}
!CHECK: %[[MAP_1:.*]] = omp.map.info var_ptr(%[[VAL_0:.*]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "b"}
!CHECK: %[[MAP_2:.*]] = omp.map.info var_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> {name = "a"}
!CHECK: omp.target is_device_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) map_entries(%[[MAP_0:.*]] -> %[[ARG0:.*]], %[[MAP_1:.*]] -> %[[ARG1:.*]], %[[MAP_2:.*]] -> %[[ARG2:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<i32>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
!CHECK: ^bb0(%[[ARG0]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %[[ARG1]]: !fir.ref<i32>, %[[ARG2]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
!$omp target map(tofrom: a,b) is_device_ptr(a)
!CHECK: {{.*}} = fir.coordinate_of %[[VAL_0:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
a = c_loc(b)
!CHECK: omp.terminator
!$omp end target
!CHECK: }
end subroutine omp_target_is_device_ptr

!===============================================================================
! Target `has_device_addr` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_has_device_addr() {
subroutine omp_target_has_device_addr
!CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFomp_target_has_device_addrEa"}
integer, pointer :: a
!CHECK: omp.target has_device_addr(%[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) map_entries({{.*}} -> {{.*}}, {{.*}} -> {{.*}} : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
!$omp target has_device_addr(a)
!CHECK: {{.*}} = fir.load %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
a = 10
!CHECK: omp.terminator
!$omp end target
!CHECK: }
end subroutine omp_target_has_device_addr
18 changes: 15 additions & 3 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1628,14 +1628,23 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac

The optional $thread_limit specifies the limit on the number of threads

The optional $nowait elliminates the implicit barrier so the parent task can make progress
The optional $nowait eliminates the implicit barrier so the parent task can make progress
even if the target task is not yet completed.

The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

TODO: is_device_ptr, defaultmap, in_reduction
The optional $is_device_ptr indicates list items are device pointers.

The optional $has_device_addr indicates that list items already have device
addresses, so they may be directly accessed from the target device. This
includes array sections.

The optional $map_operands maps data from the task’s environment to the
device environment.

TODO: defaultmap, in_reduction

}];

Expand All @@ -1645,15 +1654,18 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
Variadic<OpenMP_PointerLikeType>:$has_device_addr,
Variadic<AnyType>:$map_operands);

let regions = (region AnyRegion:$region);

let assemblyFormat = [{
oilist( `if` `(` $if_expr `)`
| `device` `(` $device `:` type($device) `)`
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `nowait` $nowait
| `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
| `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) $region attr-dict
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1809,7 +1809,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0>} : (memref<i32>) -> ()
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}

Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -510,22 +510,22 @@ return


// CHECK-LABEL: omp_target
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %device_ptr: memref<i32>, %device_addr: memref<?xi32>, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {

// Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
// CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0>} : ( i1, si32, i32 ) -> ()
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()

// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
// CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
// CHECK: omp.target map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
// CHECK: omp.target is_device_ptr(%[[VAL_4:.*]] : memref<i32>) has_device_addr(%[[VAL_5:.*]] : memref<?xi32>) map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
%mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
%mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) {
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) is_device_ptr(%device_ptr : memref<i32>) has_device_addr(%device_addr : memref<?xi32>) {
^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>):
omp.terminator
}
Expand Down