Skip to content

[flang][OpenMP][Lower] lower array subscripts for task depend #132994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 50 additions & 6 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "Clauses.h"
#include "Utils.h"

#include "flang/Lower/ConvertExprToHLFIR.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/tools.h"
Expand Down Expand Up @@ -808,7 +809,21 @@ bool ClauseProcessor::processCopyprivate(
return hasCopyPrivate;
}

bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
template <typename T>
static bool isVectorSubscript(const evaluate::Expr<T> &expr) {
if (std::optional<evaluate::DataRef> dataRef{evaluate::ExtractDataRef(expr)})
if (const auto *arrayRef = std::get_if<evaluate::ArrayRef>(&dataRef->u))
for (const evaluate::Subscript &subscript : arrayRef->subscript())
if (std::holds_alternative<evaluate::IndirectSubscriptIntegerExpr>(
subscript.u))
if (subscript.Rank() > 0)
return true;
return false;
}

bool ClauseProcessor::processDepend(lower::SymMap &symMap,
lower::StatementContext &stmtCtx,
mlir::omp::DependClauseOps &result) const {
auto process = [&](const omp::clause::Depend &clause,
const parser::CharBlock &) {
using Depend = omp::clause::Depend;
Expand All @@ -819,6 +834,7 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
auto &taskDep = std::get<Depend::TaskDep>(clause.u);
auto depType = std::get<clause::DependenceType>(taskDep.t);
auto &objects = std::get<omp::ObjectList>(taskDep.t);
fir::FirOpBuilder &builder = converter.getFirOpBuilder();

if (std::get<std::optional<omp::clause::Iterator>>(taskDep.t)) {
TODO(converter.getCurrentLocation(),
Expand All @@ -830,18 +846,46 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {

for (const omp::Object &object : objects) {
assert(object.ref() && "Expecting designator");
mlir::Value dependVar;

if (evaluate::ExtractSubstring(*object.ref())) {
TODO(converter.getCurrentLocation(),
"substring not supported for task depend");
} else if (evaluate::IsArrayElement(*object.ref())) {
TODO(converter.getCurrentLocation(),
"array sections not supported for task depend");
// Array Section
SomeExpr expr = *object.ref();
if (isVectorSubscript(expr))
TODO(converter.getCurrentLocation(),
"Vector subscripted array section for task dependency");

hlfir::EntityWithAttributes entity = convertExprToHLFIR(
converter.getCurrentLocation(), converter, expr, symMap, stmtCtx);
dependVar = entity.getBase();
} else {
semantics::Symbol *sym = object.sym();
dependVar = converter.getSymbolAddress(*sym);
}

semantics::Symbol *sym = object.sym();
const mlir::Value variable = converter.getSymbolAddress(*sym);
result.dependVars.push_back(variable);
// If we pass a mutable box e.g. !fir.ref<!fir.box<!fir.heap<...>>> then
// the runtime will use the address of the box not the address of the
// data. Flang generates a lot of memcpys between different box
// allocations so this is not a reliable way to identify the dependency.
if (auto ref = mlir::dyn_cast<fir::ReferenceType>(dependVar.getType()))
if (fir::isa_box_type(ref.getElementType()))
dependVar = builder.create<fir::LoadOp>(
converter.getCurrentLocation(), dependVar);

// The openmp dialect doesn't know what to do with boxes (and it would
// break layering to teach it about them). The dependency variable can be
// a box because it was an array section or because the original symbol
// was mapped to a box.
// Getting the address of the box data is okay because all the runtime
// ultimately cares about is the base address of the array.
if (fir::isa_box_type(dependVar.getType()))
dependVar = builder.create<fir::BoxAddrOp>(
converter.getCurrentLocation(), dependVar);

result.dependVars.push_back(dependVar);
}
};

Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ class ClauseProcessor {
bool processCopyin() const;
bool processCopyprivate(mlir::Location currentLocation,
mlir::omp::CopyprivateClauseOps &result) const;
bool processDepend(mlir::omp::DependClauseOps &result) const;
bool processDepend(lower::SymMap &symMap, lower::StatementContext &stmtCtx,
mlir::omp::DependClauseOps &result) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
Expand Down
27 changes: 15 additions & 12 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1669,15 +1669,15 @@ static void genSingleClauses(lower::AbstractConverter &converter,

static void genTargetClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TargetOperands &clauseOps,
lower::SymMap &symTable, lower::StatementContext &stmtCtx,
lower::pft::Evaluation &eval, const List<Clause> &clauses,
mlir::Location loc, mlir::omp::TargetOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processBare(clauseOps);
cp.processDepend(clauseOps);
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
if (!hostEvalInfo.empty()) {
Expand Down Expand Up @@ -1728,11 +1728,12 @@ static void genTargetDataClauses(

static void genTargetEnterExitUpdateDataClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, llvm::omp::Directive directive,
lower::SymMap &symTable, lower::StatementContext &stmtCtx,
const List<Clause> &clauses, mlir::Location loc,
llvm::omp::Directive directive,
mlir::omp::TargetEnterExitUpdateDataOperands &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDepend(clauseOps);
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processIf(directive, clauseOps);

Expand All @@ -1746,12 +1747,13 @@ static void genTargetEnterExitUpdateDataClauses(

static void genTaskClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
lower::SymMap &symTable,
lower::StatementContext &stmtCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskOperands &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processDepend(clauseOps);
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processFinal(stmtCtx, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
cp.processMergeable(clauseOps);
Expand Down Expand Up @@ -2194,8 +2196,8 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::omp::TargetOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> mapSyms, isDevicePtrSyms,
hasDeviceAddrSyms;
genTargetClauses(converter, semaCtx, stmtCtx, eval, item->clauses, loc,
clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms);
genTargetClauses(converter, semaCtx, symTable, stmtCtx, eval, item->clauses,
loc, clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms);

DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/
Expand Down Expand Up @@ -2415,7 +2417,7 @@ static OpTy genTargetEnterExitUpdateDataOp(
}

mlir::omp::TargetEnterExitUpdateDataOperands clauseOps;
genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx,
genTargetEnterExitUpdateDataClauses(converter, semaCtx, symTable, stmtCtx,
item->clauses, loc, directive, clauseOps);

return firOpBuilder.create<OpTy>(loc, clauseOps);
Expand All @@ -2428,7 +2430,8 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
lower::StatementContext stmtCtx;
Copy link
Contributor Author

@tblah tblah Mar 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While working on the next patch in this series I've noticed that this stmtCtx seems not to be scoped correctly. If any cleanup code is added to it, it will emit that cleanup code when it is destroyed. Unfortunately, when this goes out of scope at the end of genTaskOp the builder insertion point is inside of the task operation. It would make more sense to do the cleanup after the end of the task operation to avoid adding unnecessary shared variables (this might be only a theoretical bug because I haven't yet found an example where any cleanup code is still present by the time we reach FIR.)

I have a proof of concept fix which I will finish off on Monday.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to understand the problem here. This stmtCtx is used in managing and cleaning up intermediates while generating expressions of certain clauses. At the time the cleanup code is emitted, the insertion point is inside the task's region. Is it after the operations in task's region are generated, or before that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's discuss this on the PR for this change. Hopefully the example will make it clearer: #133891 (comment)

mlir::omp::TaskOperands clauseOps;
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
genTaskClauses(converter, semaCtx, symTable, stmtCtx, item->clauses, loc,
clauseOps);

if (!enableDelayedPrivatization)
return genOpWithBody<mlir::omp::TaskOp>(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
! RUN: %not_todo_cmd bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s

! CHECK: Vector subscripted array section for task dependency
subroutine vectorSubscriptArraySection(array, indices)
integer :: array(:)
integer :: indices(:)

!$omp task depend (in: array(indices))
!$omp end task
end subroutine
18 changes: 18 additions & 0 deletions flang/test/Lower/OpenMP/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,24 @@ subroutine omp_target_enter_depend
return
end subroutine omp_target_enter_depend

!CHECK-LABEL: func.func @_QPomp_target_enter_depend_section() {
subroutine omp_target_enter_depend_section
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_enter_depend_sectionEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
integer :: a(1024)

!CHECK: %[[DESIGNATE:.*]] = hlfir.designate %[[A]]#0 ({{.*}}) shape %{{.*}} : (!fir.ref<!fir.array<1024xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<512xi32>>
!CHECK: omp.task depend(taskdependout -> %[[DESIGNATE]] : !fir.ref<!fir.array<512xi32>>) private({{.*}}) {
!$omp task depend(out: a(1:512))
call foo(a)
!$omp end task
!CHECK: %[[DESIGNATE2:.*]] = hlfir.designate %[[A]]#0 ({{.*}}) shape %{{.*}} : (!fir.ref<!fir.array<1024xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<512xi32>>
!CHECK: %[[BOUNDS:.*]] = omp.map.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
!CHECK: omp.target_enter_data depend(taskdependin -> %[[DESIGNATE2]] : !fir.ref<!fir.array<512xi32>>) map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target enter data map(to: a) depend(in: a(1:512))
return
end subroutine omp_target_enter_depend_section

!===============================================================================
! Target_Enter Map types
!===============================================================================
Expand Down
51 changes: 51 additions & 0 deletions flang/test/Lower/OpenMP/task-depend-array-section.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

subroutine knownShape(array)
integer :: array(10)

!$omp task depend(in: array(2:8))
!$omp end task
end subroutine

! CHECK-LABEL: func.func @_QPknownshape(
! CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "array"}) {
! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
! CHECK: %[[VAL_2:.*]] = arith.constant 10 : index
! CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
! CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) dummy_scope %[[VAL_1]] {uniq_name = "_QFknownshapeEarray"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
! CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
! CHECK: %[[VAL_6:.*]] = arith.constant 8 : index
! CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_8:.*]] = arith.constant 7 : index
! CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
! CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_5]]:%[[VAL_6]]:%[[VAL_7]]) shape %[[VAL_9]] : (!fir.ref<!fir.array<10xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<7xi32>>
! CHECK: omp.task depend(taskdependin -> %[[VAL_10]] : !fir.ref<!fir.array<7xi32>>) {
! CHECK: omp.terminator
! CHECK: }
! CHECK: return
! CHECK: }


subroutine assumedShape(array)
integer :: array(:)

!$omp task depend(in: array(2:8:2))
!$omp end task
end subroutine

! CHECK-LABEL: func.func @_QPassumedshape(
! CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}) {
! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]] {uniq_name = "_QFassumedshapeEarray"} : (!fir.box<!fir.array<?xi32>>, !fir.dscope) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
! CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
! CHECK: %[[VAL_4:.*]] = arith.constant 8 : index
! CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
! CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
! CHECK: %[[VAL_7:.*]] = fir.shape %[[VAL_6]] : (index) -> !fir.shape<1>
! CHECK: %[[VAL_8:.*]] = hlfir.designate %[[VAL_2]]#0 (%[[VAL_3]]:%[[VAL_4]]:%[[VAL_5]]) shape %[[VAL_7]] : (!fir.box<!fir.array<?xi32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<4xi32>>
! CHECK: %[[VAL_9:.*]] = fir.box_addr %[[VAL_8]] : (!fir.box<!fir.array<4xi32>>) -> !fir.ref<!fir.array<4xi32>>
! CHECK: omp.task depend(taskdependin -> %[[VAL_9]] : !fir.ref<!fir.array<4xi32>>) {
! CHECK: omp.terminator
! CHECK: }
! CHECK: return
! CHECK: }
19 changes: 18 additions & 1 deletion flang/test/Lower/OpenMP/task.f90
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ subroutine task_depend_non_int()
character(len = 15) :: x
integer, allocatable :: y
complex :: z
!CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<!fir.char<1,15>>, taskdependin -> %{{.+}} : !fir.ref<!fir.box<!fir.heap<i32>>>, taskdependin -> %{{.+}} : !fir.ref<complex<f32>>) {
!CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<!fir.char<1,15>>, taskdependin -> %{{.+}} : !fir.heap<i32>, taskdependin -> %{{.+}} : !fir.ref<complex<f32>>) {
!$omp task depend(in : x, y, z)
!CHECK: omp.terminator
!$omp end task
Expand Down Expand Up @@ -158,6 +158,23 @@ subroutine task_depend_multi_task()
!$omp end task
end subroutine task_depend_multi_task

subroutine task_depend_box(array)
integer :: array(:)
!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %{{.*}} : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
!CHECK: omp.task depend(taskdependin -> %[[BOX_ADDR]] : !fir.ref<!fir.array<?xi32>>)
!$omp task depend(in: array)
!$omp end task
end subroutine

subroutine task_depend_mutable_box(alloc)
integer, allocatable :: alloc
!CHECK: %[[LOAD:.*]] = fir.load %{{.*}} : !fir.ref<!fir.box<!fir.heap<i32>>>
!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[LOAD]] : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
!CHECK: omp.task depend(taskdependin -> %[[BOX_ADDR]] : !fir.heap<i32>)
!$omp task depend(in: alloc)
!$omp end task
end subroutine

!===============================================================================
! `private` clause
!===============================================================================
Expand Down