Skip to content

[flang][openmp] - depend clause support in target, target enter/update/exit data constructs #81610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand;
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Value> mapOperands;
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;

Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
llvm::omp::Directive directive;
Expand All @@ -784,6 +785,7 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(directiveName, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processDepend(dependTypeOperands, dependOperands);
cp.processNowait(nowaitAttr);

if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
Expand All @@ -796,12 +798,13 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
}

cp.processTODO<Fortran::parser::OmpClause::Depend>(currentLocation,
directive);

return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
deviceOperand, nullptr, mlir::ValueRange(),
nowaitAttr, mapOperands);
return firOpBuilder.create<OpTy>(
currentLocation, ifClauseOperand, deviceOperand,
dependTypeOperands.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
dependTypeOperands),
dependOperands, nowaitAttr, mapOperands);
}

// This functions creates a block for the body of the targetOp's region. It adds
Expand Down Expand Up @@ -968,7 +971,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Value> mapOperands;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
llvm::SmallVector<mlir::Type> mapSymTypes;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
Expand All @@ -978,11 +982,12 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLimitOperand);
cp.processDepend(dependTypeOperands, dependOperands);
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
&mapSymLocs, &mapSymbols);

cp.processTODO<Fortran::parser::OmpClause::Private,
Fortran::parser::OmpClause::Depend,
Fortran::parser::OmpClause::Firstprivate,
Fortran::parser::OmpClause::IsDevicePtr,
Fortran::parser::OmpClause::HasDeviceAddr,
Expand All @@ -992,7 +997,6 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::parser::OmpClause::UsesAllocators,
Fortran::parser::OmpClause::Defaultmap>(
currentLocation, llvm::omp::Directive::OMPD_target);

// 5.8.1 Implicit Data-Mapping Attribute Rules
// The following code follows the implicit data-mapping rules to map all the
// symbols used inside the region that have not been explicitly mapped using
Expand Down Expand Up @@ -1066,7 +1070,11 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,

auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
dependTypeOperands.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
dependTypeOperands),
dependOperands, nowaitAttr, mapOperands);

genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
mapSymLocs, mapSymbols, currentLocation);
Expand Down
8 changes: 8 additions & 0 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2815,6 +2815,14 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Device &x) {

void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
CheckAllowed(llvm::omp::Clause::OMPC_depend);
if ((std::holds_alternative<parser::OmpDependClause::Source>(x.v.u) ||
std::holds_alternative<parser::OmpDependClause::Sink>(x.v.u)) &&
GetContext().directive != llvm::omp::OMPD_ordered) {
context_.Say(GetContext().clauseSource,
"DEPEND(SOURCE) or DEPEND(SINK : vec) can be used only with the ordered"
" directive. Used here in the %s construct."_err_en_US,
parser::ToUpperCaseLetters(getDirectiveName(GetContext().directive)));
}
if (const auto *inOut{std::get_if<parser::OmpDependClause::InOut>(&x.v.u)}) {
const auto &designators{std::get<std::list<parser::Designator>>(inOut->t)};
for (const auto &ele : designators) {
Expand Down
85 changes: 85 additions & 0 deletions flang/test/Lower/OpenMP/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ subroutine omp_target_enter_simple
return
end subroutine omp_target_enter_simple

!===============================================================================
! Target_Enter `depend` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_enter_depend() {
subroutine omp_target_enter_depend
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_enter_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think it's not necessary to check the hlfir.declare operation here, since we're only really interested in matching the same SSA value in omp.task depend(taskdependout -> %[[A:.*]] : ... and in omp.target_enter_data ... depend(taskdependin -> %[[A]] : .... Same comment for the other tests added in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I completely agree. Even I am not a fan of adding CHECKs for aspects that are not intended to be affected by the patch/PR in question. I did it this way only to be consistent with the rest of the this test file. I was bemused by that myself.

integer :: a(1024)

!CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp task depend(out: a)
call foo(a)
!$omp end task
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!$omp target enter data map(to: a) depend(in: a)
return
end subroutine omp_target_enter_depend

!===============================================================================
! Target_Enter Map types
!===============================================================================
Expand Down Expand Up @@ -134,6 +154,45 @@ subroutine omp_target_exit_device
!$omp target exit data map(from: a) device(d)
end subroutine omp_target_exit_device

!===============================================================================
! Target_Exit `depend` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_exit_depend() {
subroutine omp_target_exit_depend
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_exit_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
integer :: a(1024)
!CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp task depend(out: a)
call foo(a)
!$omp end task
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!$omp target exit data map(from: a) depend(out: a)
end subroutine omp_target_exit_depend


!===============================================================================
! Target_Update `depend` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_update_depend() {
subroutine omp_target_update_depend
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_update_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
integer :: a(1024)

!CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp task depend(out: a)
call foo(a)
!$omp end task

!CHECK: %[[BOUNDS:.*]] = omp.bounds
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!$omp target update to(a) depend(in:a)
end subroutine omp_target_update_depend

!===============================================================================
! Target_Update `to` clause
!===============================================================================
Expand Down Expand Up @@ -295,6 +354,32 @@ subroutine omp_target
!CHECK: }
end subroutine omp_target

!===============================================================================
! Target with region `depend` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_depend() {
subroutine omp_target_depend
!CHECK: %[[EXTENT_A:.*]] = arith.constant 1024 : index
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
integer :: a(1024)
!CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp task depend(out: a)
call foo(a)
!$omp end task
!CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
!CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
!CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
!CHECK: %[[BOUNDS_A:.*]] = omp.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
!CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!$omp target map(tofrom: a) depend(in: a)
a(1) = 10
!CHECK: omp.terminator
!$omp end target
!CHECK: }
end subroutine omp_target_depend

!===============================================================================
! Target implicit capture
!===============================================================================
Expand Down
1 change: 1 addition & 0 deletions flang/test/Semantics/OpenMP/clause-validity01.f90
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@
!$omp taskyield
!$omp barrier
!$omp taskwait
!ERROR: DEPEND(SOURCE) or DEPEND(SINK : vec) can be used only with the ordered directive. Used here in the TASKWAIT construct.
!$omp taskwait depend(source)
! !$omp taskwait depend(sink:i-1)
! !$omp target enter data map(to:arrayA) map(alloc:arrayB)
Expand Down