Skip to content

[mlir][openmp][flang] - MLIR support for the depend clause (omp dialect) in offloading directives #80626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2662,7 +2662,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand;
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Value> mapOperands;
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;

Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
llvm::omp::Directive directive;
Expand Down Expand Up @@ -2697,12 +2698,15 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
mapOperands);
}
cp.processDepend(dependTypeOperands, dependOperands);

cp.processTODO<Fortran::parser::OmpClause::Depend>(currentLocation,
directive);

return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
deviceOperand, nowaitAttr, mapOperands);
return firOpBuilder.create<OpTy>(
currentLocation, ifClauseOperand, deviceOperand,
dependTypeOperands.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
dependTypeOperands),
dependOperands, nowaitAttr, mapOperands);
}

// This functions creates a block for the body of the targetOp's region. It adds
Expand Down Expand Up @@ -2867,7 +2871,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Value> mapOperands;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
llvm::SmallVector<mlir::Type> mapSymTypes;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
Expand All @@ -2880,8 +2885,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols);
cp.processDepend(dependTypeOperands, dependOperands);
cp.processTODO<Fortran::parser::OmpClause::Private,
Fortran::parser::OmpClause::Depend,
Fortran::parser::OmpClause::Firstprivate,
Fortran::parser::OmpClause::IsDevicePtr,
Fortran::parser::OmpClause::HasDeviceAddr,
Expand All @@ -2891,7 +2896,6 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::parser::OmpClause::UsesAllocators,
Fortran::parser::OmpClause::Defaultmap>(
currentLocation, llvm::omp::Directive::OMPD_target);

// 5.8.1 Implicit Data-Mapping Attribute Rules
// The following code follows the implicit data-mapping rules to map all the
// symbols used inside the region that have not been explicitly mapped using
Expand Down Expand Up @@ -2962,7 +2966,11 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,

auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
nowaitAttr, mapOperands);
dependTypeOperands.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
dependTypeOperands),
dependOperands, nowaitAttr, mapOperands);

genBodyOfTargetOp(converter, eval, genNested, targetOp, mapSymTypes,
mapSymLocs, mapSymbols, currentLocation);
Expand Down
8 changes: 8 additions & 0 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2815,6 +2815,14 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Device &x) {

void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
CheckAllowed(llvm::omp::Clause::OMPC_depend);
if ((std::holds_alternative<parser::OmpDependClause::Source>(x.v.u) ||
std::holds_alternative<parser::OmpDependClause::Sink>(x.v.u)) &&
GetContext().directive != llvm::omp::OMPD_ordered) {
context_.Say(GetContext().clauseSource,
"DEPEND(SOURCE) or DEPEND(SINK : vec) can be used only with the ordered"
" directive. Used here in the %s construct."_err_en_US,
parser::ToUpperCaseLetters(getDirectiveName(GetContext().directive)));
}
if (const auto *inOut{std::get_if<parser::OmpDependClause::InOut>(&x.v.u)}) {
const auto &designators{std::get<std::list<parser::Designator>>(inOut->t)};
for (const auto &ele : designators) {
Expand Down
85 changes: 85 additions & 0 deletions flang/test/Lower/OpenMP/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ subroutine omp_target_enter_simple
return
end subroutine omp_target_enter_simple

!===============================================================================
! Target_Enter `depend` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_enter_depend() {
subroutine omp_target_enter_depend
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_enter_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
integer :: a(1024)

!CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp task depend(out: a)
call foo(a)
!$omp end task
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!$omp target enter data map(to: a) depend(in: a)
return
end subroutine omp_target_enter_depend

!===============================================================================
! Target_Enter Map types
!===============================================================================
Expand Down Expand Up @@ -134,6 +154,45 @@ subroutine omp_target_exit_device
!$omp target exit data map(from: a) device(d)
end subroutine omp_target_exit_device

!===============================================================================
! Target_Exit `depend` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_exit_depend() {
subroutine omp_target_exit_depend
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_exit_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
integer :: a(1024)
!CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp task depend(out: a)
call foo(a)
!$omp end task
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!$omp target exit data map(from: a) depend(out: a)
end subroutine omp_target_exit_depend


!===============================================================================
! Target_Update `depend` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_update_depend() {
subroutine omp_target_update_depend
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_update_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
integer :: a(1024)

!CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp task depend(out: a)
call foo(a)
!$omp end task

!CHECK: %[[BOUNDS:.*]] = omp.bounds
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!$omp target update to(a) depend(in:a)
end subroutine omp_target_update_depend

!===============================================================================
! Target_Update `to` clause
!===============================================================================
Expand Down Expand Up @@ -295,6 +354,32 @@ subroutine omp_target
!CHECK: }
end subroutine omp_target

!===============================================================================
! Target with region `depend` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_depend() {
subroutine omp_target_depend
!CHECK: %[[EXTENT_A:.*]] = arith.constant 1024 : index
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
integer :: a(1024)
!CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp task depend(out: a)
call foo(a)
!$omp end task
!CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
!CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
!CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
!CHECK: %[[BOUNDS_A:.*]] = omp.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
!CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp target map(tofrom: a) depend(in: a)
a(1) = 10
!CHECK: omp.terminator
!$omp end target
!CHECK: }
end subroutine omp_target_depend

!===============================================================================
! Target implicit capture
!===============================================================================
Expand Down
1 change: 1 addition & 0 deletions flang/test/Semantics/OpenMP/clause-validity01.f90
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@
!$omp taskyield
!$omp barrier
!$omp taskwait
!ERROR: DEPEND(SOURCE) or DEPEND(SINK : vec) can be used only with the ordered directive. Used here in the TASKWAIT construct.
!$omp taskwait depend(source)
! !$omp taskwait depend(sink:i-1)
! !$omp target enter data map(to:arrayA) map(alloc:arrayB)
Expand Down
22 changes: 22 additions & 0 deletions flang/test/Semantics/OpenMP/task_depend_source.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
! RUN: %python %S/../test_errors.py %s %flang -fopenmp
! OpenMP Version 4.5
! 2.13.9 Depend Clause
! depend(source) can be used only with the ordered construct
program main
implicit none
integer :: number = 0

!ERROR: DEPEND(SOURCE) or DEPEND(SINK : vec) can be used only with the ordered directive. Used here in the TASK construct.
!$omp task depend(source)
number = 1
!$omp end task


!$omp task
number = number + 1
!$omp end task

!$omp task
print*, number
!$omp end task
end program main
37 changes: 32 additions & 5 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;

def ClauseTaskDepend : I32EnumAttr<
"ClauseTaskDepend",
"task depend clause",
"depend clause in a target or task construct",
[ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::omp";
Expand Down Expand Up @@ -1351,11 +1351,17 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",

The $map_types specifies the types and modifiers for the map clause.

TODO: depend clause and map_type_modifier values iterator and mapper.
The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

TODO: map_type_modifier values iterator and mapper.
}];

let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);

Expand All @@ -1364,6 +1370,7 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];

Expand Down Expand Up @@ -1397,11 +1404,17 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",

The $map_types specifies the types and modifiers for the map clause.

TODO: depend clause and map_type_modifier values iterator and mapper.
The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

TODO: map_type_modifier values iterator and mapper.
}];

let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);

Expand All @@ -1410,6 +1423,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];

Expand Down Expand Up @@ -1447,11 +1461,16 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
during verification to make sure the restrictions for target update are
respected.

TODO: depend clause
The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

}];

let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$motion_operands);

Expand All @@ -1460,6 +1479,7 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `motion_entries` `(` $motion_operands `:` type($motion_operands) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];

Expand Down Expand Up @@ -1488,13 +1508,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
The optional $nowait elliminates the implicit barrier so the parent task can make progress
even if the target task is not yet completed.

TODO: is_device_ptr, depend, defaultmap, in_reduction
The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

TODO: is_device_ptr, defaultmap, in_reduction

}];

let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
Optional<AnyInteger>:$thread_limit,
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);

Expand All @@ -1506,6 +1532,7 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterfa
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `nowait` $nowait
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) $region attr-dict
}];

Expand Down
22 changes: 17 additions & 5 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ static LogicalResult verifyDependVarList(Operation *op,
return op->emitOpError() << "expected as many depend values"
" as depend variables";
} else {
if (depends)
if (depends && !depends->empty())
return op->emitOpError() << "unexpected depend values";
return success();
}
Expand Down Expand Up @@ -965,19 +965,31 @@ LogicalResult DataOp::verify() {
}

LogicalResult EnterDataOp::verify() {
return verifyMapClause(*this, getMapOperands());
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
return failed(verifyDependVars) ? verifyDependVars
: verifyMapClause(*this, getMapOperands());
}

LogicalResult ExitDataOp::verify() {
return verifyMapClause(*this, getMapOperands());
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
return failed(verifyDependVars) ? verifyDependVars
: verifyMapClause(*this, getMapOperands());
}

LogicalResult UpdateDataOp::verify() {
return verifyMapClause(*this, getMotionOperands());
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
return failed(verifyDependVars) ? verifyDependVars
: verifyMapClause(*this, getMotionOperands());
}

LogicalResult TargetOp::verify() {
return verifyMapClause(*this, getMapOperands());
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
return failed(verifyDependVars) ? verifyDependVars
: verifyMapClause(*this, getMapOperands());
}

//===----------------------------------------------------------------------===//
Expand Down
Loading