Skip to content

[Flang] [OpenMP] [MLIR] [Lowering] Add lowering support for IS_DEVICE_PTR and HAS_DEVICE_ADDR clauses on OMP TARGET directive. #88206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 11, 2024
Merged
29 changes: 29 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,20 @@ bool ClauseProcessor::processDepend(
});
}

bool ClauseProcessor::processHasDeviceAddr(
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
const {
return findRepeatableClause<omp::clause::HasDeviceAddr>(
[&](const omp::clause::HasDeviceAddr &devAddrClause,
const Fortran::parser::CharBlock &) {
addUseDeviceClause(converter, devAddrClause.v, operands, isDeviceTypes,
isDeviceLocs, isDeviceSymbols);
});
}

bool ClauseProcessor::processIf(
omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const {
Expand All @@ -771,6 +785,20 @@ bool ClauseProcessor::processIf(
return found;
}

bool ClauseProcessor::processIsDevicePtr(
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &isDeviceSymbols)
const {
return findRepeatableClause<omp::clause::IsDevicePtr>(
[&](const omp::clause::IsDevicePtr &devPtrClause,
const Fortran::parser::CharBlock &) {
addUseDeviceClause(converter, devPtrClause.v, operands, isDeviceTypes,
isDeviceLocs, isDeviceSymbols);
});
}

bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::Link>(
Expand Down Expand Up @@ -993,6 +1021,7 @@ bool ClauseProcessor::processUseDevicePtr(
useDeviceLocs, useDeviceSymbols);
});
}

} // namespace omp
} // namespace lower
} // namespace Fortran
12 changes: 12 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ class ClauseProcessor {
bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
bool processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value &result) const;
bool
processHasDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&isDeviceSymbols) const;
bool processHint(mlir::IntegerAttr &result) const;
bool processMergeable(mlir::UnitAttr &result) const;
bool processNowait(mlir::UnitAttr &result) const;
Expand Down Expand Up @@ -104,6 +110,12 @@ class ClauseProcessor {
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const;
bool
processIsDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&isDeviceSymbols) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;

// This method is used to process a map clause.
Expand Down
22 changes: 17 additions & 5 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,11 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Type> mapSymTypes;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
llvm::SmallVector<mlir::Value> devicePtrOperands, deviceAddrOperands;
llvm::SmallVector<mlir::Type> devicePtrTypes, deviceAddrTypes;
llvm::SmallVector<mlir::Location> devicePtrLocs, deviceAddrLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> devicePtrSymbols,
deviceAddrSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(llvm::omp::Directive::OMPD_target, ifClauseOperand);
Expand All @@ -1303,11 +1308,15 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
&mapSymLocs, &mapSymbols);
cp.processIsDevicePtr(devicePtrOperands, devicePtrTypes, devicePtrLocs,
devicePtrSymbols);
cp.processHasDeviceAddr(deviceAddrOperands, deviceAddrTypes, deviceAddrLocs,
deviceAddrSymbols);

cp.processTODO<clause::Private, clause::Firstprivate, clause::IsDevicePtr,
clause::HasDeviceAddr, clause::Reduction, clause::InReduction,
clause::Allocate, clause::UsesAllocators, clause::Defaultmap>(
currentLocation, llvm::omp::Directive::OMPD_target);
cp.processTODO<clause::Private, clause::Firstprivate, clause::Reduction,
clause::InReduction, clause::Allocate, clause::UsesAllocators,
clause::Defaultmap>(currentLocation,
llvm::omp::Directive::OMPD_target);

// 5.8.1 Implicit Data-Mapping Attribute Rules
// The following code follows the implicit data-mapping rules to map all the
Expand Down Expand Up @@ -1400,7 +1409,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
dependTypeOperands),
dependOperands, nowaitAttr, mapOperands);
dependOperands, nowaitAttr, devicePtrOperands, deviceAddrOperands,
mapOperands);

genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
mapSymLocs, mapSymbols, currentLocation);
Expand Down Expand Up @@ -2059,6 +2069,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
!std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::IsDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::HasDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
Expand Down
43 changes: 42 additions & 1 deletion flang/test/Lower/OpenMP/FIR/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -506,4 +506,45 @@ subroutine omp_target_parallel_do
!CHECK: omp.terminator
!CHECK: }
!$omp end target parallel do
end subroutine omp_target_parallel_do
end subroutine omp_target_parallel_do

!===============================================================================
! Target `is_device_ptr` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
subroutine omp_target_is_device_ptr
use iso_c_binding, only : c_ptr, c_loc
!CHECK: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "a", uniq_name = "_QFomp_target_is_device_ptrEa"}
type(c_ptr) :: a
!CHECK: %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "b", fir.target, uniq_name = "_QFomp_target_is_device_ptrEb"}
integer, target :: b
!CHECK: %[[MAP_0:.*]] = omp.map.info var_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> {name = "a"}
!CHECK: %[[MAP_1:.*]] = omp.map.info var_ptr(%[[VAL_0:.*]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "b"}
!CHECK: %[[MAP_2:.*]] = omp.map.info var_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> {name = "a"}
!CHECK: omp.target is_device_ptr(%[[DEV_PTR:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) map_entries(%[[MAP_0:.*]] -> %[[ARG0:.*]], %[[MAP_1:.*]] -> %[[ARG1:.*]], %[[MAP_2:.*]] -> %[[ARG2:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<i32>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
!CHECK: ^bb0(%[[ARG0]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %[[ARG1]]: !fir.ref<i32>, %[[ARG2]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
!$omp target map(tofrom: a,b) is_device_ptr(a)
!CHECK: {{.*}} = fir.coordinate_of %[[VAL_0:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
a = c_loc(b)
!CHECK: omp.terminator
!$omp end target
!CHECK: }
end subroutine omp_target_is_device_ptr

!===============================================================================
! Target `has_device_addr` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_has_device_addr() {
subroutine omp_target_has_device_addr
!CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFomp_target_has_device_addrEa"}
integer, pointer :: a
!CHECK: omp.target has_device_addr(%[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) map_entries({{.*}} -> {{.*}}, {{.*}} -> {{.*}} : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
!$omp target has_device_addr(a)
!CHECK: {{.*}} = fir.load %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
a = 10
!CHECK: omp.terminator
!$omp end target
!CHECK: }
end subroutine omp_target_has_device_addr
16 changes: 11 additions & 5 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ struct GrainsizeClauseOps {
Value grainsizeVar;
};

struct HasDeviceAddrOps {
llvm::SmallVector<Value> hasDeviceAddrVars;
};
struct HintClauseOps {
IntegerAttr hintAttr;
};
Expand All @@ -94,6 +97,10 @@ struct InReductionClauseOps {
llvm::SmallVector<Attribute> inReductionDeclSymbols;
};

struct IsDevicePtrOps {
llvm::SmallVector<Value> isDevicePtrVars;
};

struct LinearClauseOps {
llvm::SmallVector<Value> linearVars, linearStepVars;
};
Expand Down Expand Up @@ -251,13 +258,12 @@ using SimdLoopClauseOps =
using SingleClauseOps = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
NowaitClauseOps, PrivateClauseOps>;

// TODO `defaultmap`, `has_device_addr`, `is_device_ptr`, `uses_allocators`
// clauses.
// TODO `defaultmap`, `uses_allocators` clauses.
using TargetClauseOps =
detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
IfClauseOps, InReductionClauseOps, MapClauseOps,
NowaitClauseOps, PrivateClauseOps, ReductionClauseOps,
ThreadLimitClauseOps>;
HasDeviceAddrOps, IfClauseOps, InReductionClauseOps,
IsDevicePtrOps, MapClauseOps, NowaitClauseOps,
PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;

using TargetDataClauseOps = detail::Clauses<DeviceClauseOps, IfClauseOps,
MapClauseOps, UseDeviceClauseOps>;
Expand Down
18 changes: 15 additions & 3 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1678,14 +1678,23 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac

The optional $thread_limit specifies the limit on the number of threads

The optional $nowait elliminates the implicit barrier so the parent task can make progress
The optional $nowait eliminates the implicit barrier so the parent task can make progress
even if the target task is not yet completed.

The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

TODO: is_device_ptr, defaultmap, in_reduction
The optional $is_device_ptr indicates list items are device pointers.

The optional $has_device_addr indicates that list items already have device
addresses, so they may be directly accessed from the target device. This
includes array sections.

The optional $map_operands maps data from the task’s environment to the
device environment.

TODO: defaultmap, in_reduction

}];

Expand All @@ -1695,8 +1704,9 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
Variadic<OpenMP_PointerLikeType>:$has_device_addr,
Variadic<AnyType>:$map_operands);

let regions = (region AnyRegion:$region);

let builders = [
Expand All @@ -1708,6 +1718,8 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
| `device` `(` $device `:` type($device) `)`
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `nowait` $nowait
| `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
| `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) $region attr-dict
Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1258,10 +1258,11 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
// inReductionDeclSymbols, privateVars, privatizers, reductionVars,
// reductionByRefAttr, reductionDeclSymbols.
TargetOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
clauses.threadLimitVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs),
clauses.dependVars, clauses.nowaitAttr, clauses.mapVars);
TargetOp::build(
builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
clauses.mapVars);
}

LogicalResult TargetOp::verify() {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1809,7 +1809,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0>} : (memref<i32>) -> ()
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}

Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -510,22 +510,22 @@ return


// CHECK-LABEL: omp_target
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {
func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %device_ptr: memref<i32>, %device_addr: memref<?xi32>, %map1: memref<?xi32>, %map2: memref<?xi32>) -> () {

// Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
// CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0>} : ( i1, si32, i32 ) -> ()
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()

// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
// CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
// CHECK: omp.target map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
// CHECK: omp.target is_device_ptr(%[[VAL_4:.*]] : memref<i32>) has_device_addr(%[[VAL_5:.*]] : memref<?xi32>) map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
%mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
%mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) {
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) is_device_ptr(%device_ptr : memref<i32>) has_device_addr(%device_addr : memref<?xi32>) {
^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>):
omp.terminator
}
Expand Down