Skip to content

Commit e17d864

Browse files
authored
[flang][OpenMP][Lower] lower array subscripts for task depend (#132994)
The OpenMP standard says that all dependencies in the same set of inter-dependent tasks must be non-overlapping. This simplification means that the OpenMP runtime only needs to keep track of the base addresses of dependency variables. This can be seen in kmp_taskdeps.cpp, which stores task dependency information in a hash table, using the base address as a key. This patch generates a rebox operation to slice boxed arrays, but only the box data address is used for the task dependency. The extra box is optimized away by LLVM at O3. Vector subscripts are TODO (I will address them in my next patch). This also fixes a bug for ordinary subscripts when the symbol was mapped to a box. Fixes #132647
1 parent dca7e03 commit e17d864

File tree

7 files changed

+165
-20
lines changed

7 files changed

+165
-20
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "Clauses.h"
1515
#include "Utils.h"
1616

17+
#include "flang/Lower/ConvertExprToHLFIR.h"
1718
#include "flang/Lower/PFTBuilder.h"
1819
#include "flang/Parser/tools.h"
1920
#include "flang/Semantics/tools.h"
@@ -808,7 +809,21 @@ bool ClauseProcessor::processCopyprivate(
808809
return hasCopyPrivate;
809810
}
810811

811-
bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
812+
template <typename T>
813+
static bool isVectorSubscript(const evaluate::Expr<T> &expr) {
814+
if (std::optional<evaluate::DataRef> dataRef{evaluate::ExtractDataRef(expr)})
815+
if (const auto *arrayRef = std::get_if<evaluate::ArrayRef>(&dataRef->u))
816+
for (const evaluate::Subscript &subscript : arrayRef->subscript())
817+
if (std::holds_alternative<evaluate::IndirectSubscriptIntegerExpr>(
818+
subscript.u))
819+
if (subscript.Rank() > 0)
820+
return true;
821+
return false;
822+
}
823+
824+
bool ClauseProcessor::processDepend(lower::SymMap &symMap,
825+
lower::StatementContext &stmtCtx,
826+
mlir::omp::DependClauseOps &result) const {
812827
auto process = [&](const omp::clause::Depend &clause,
813828
const parser::CharBlock &) {
814829
using Depend = omp::clause::Depend;
@@ -819,6 +834,7 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
819834
auto &taskDep = std::get<Depend::TaskDep>(clause.u);
820835
auto depType = std::get<clause::DependenceType>(taskDep.t);
821836
auto &objects = std::get<omp::ObjectList>(taskDep.t);
837+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
822838

823839
if (std::get<std::optional<omp::clause::Iterator>>(taskDep.t)) {
824840
TODO(converter.getCurrentLocation(),
@@ -830,18 +846,46 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
830846

831847
for (const omp::Object &object : objects) {
832848
assert(object.ref() && "Expecting designator");
849+
mlir::Value dependVar;
833850

834851
if (evaluate::ExtractSubstring(*object.ref())) {
835852
TODO(converter.getCurrentLocation(),
836853
"substring not supported for task depend");
837854
} else if (evaluate::IsArrayElement(*object.ref())) {
838-
TODO(converter.getCurrentLocation(),
839-
"array sections not supported for task depend");
855+
// Array Section
856+
SomeExpr expr = *object.ref();
857+
if (isVectorSubscript(expr))
858+
TODO(converter.getCurrentLocation(),
859+
"Vector subscripted array section for task dependency");
860+
861+
hlfir::EntityWithAttributes entity = convertExprToHLFIR(
862+
converter.getCurrentLocation(), converter, expr, symMap, stmtCtx);
863+
dependVar = entity.getBase();
864+
} else {
865+
semantics::Symbol *sym = object.sym();
866+
dependVar = converter.getSymbolAddress(*sym);
840867
}
841868

842-
semantics::Symbol *sym = object.sym();
843-
const mlir::Value variable = converter.getSymbolAddress(*sym);
844-
result.dependVars.push_back(variable);
869+
// If we pass a mutable box e.g. !fir.ref<!fir.box<!fir.heap<...>>> then
870+
// the runtime will use the address of the box not the address of the
871+
// data. Flang generates a lot of memcpys between different box
872+
// allocations so this is not a reliable way to identify the dependency.
873+
if (auto ref = mlir::dyn_cast<fir::ReferenceType>(dependVar.getType()))
874+
if (fir::isa_box_type(ref.getElementType()))
875+
dependVar = builder.create<fir::LoadOp>(
876+
converter.getCurrentLocation(), dependVar);
877+
878+
// The openmp dialect doesn't know what to do with boxes (and it would
879+
// break layering to teach it about them). The dependency variable can be
880+
// a box because it was an array section or because the original symbol
881+
// was mapped to a box.
882+
// Getting the address of the box data is okay because all the runtime
883+
// ultimately cares about is the base address of the array.
884+
if (fir::isa_box_type(dependVar.getType()))
885+
dependVar = builder.create<fir::BoxAddrOp>(
886+
converter.getCurrentLocation(), dependVar);
887+
888+
result.dependVars.push_back(dependVar);
845889
}
846890
};
847891

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ class ClauseProcessor {
104104
bool processCopyin() const;
105105
bool processCopyprivate(mlir::Location currentLocation,
106106
mlir::omp::CopyprivateClauseOps &result) const;
107-
bool processDepend(mlir::omp::DependClauseOps &result) const;
107+
bool processDepend(lower::SymMap &symMap, lower::StatementContext &stmtCtx,
108+
mlir::omp::DependClauseOps &result) const;
108109
bool
109110
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
110111
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,15 +1672,15 @@ static void genSingleClauses(lower::AbstractConverter &converter,
16721672

16731673
static void genTargetClauses(
16741674
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1675-
lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval,
1676-
const List<Clause> &clauses, mlir::Location loc,
1677-
mlir::omp::TargetOperands &clauseOps,
1675+
lower::SymMap &symTable, lower::StatementContext &stmtCtx,
1676+
lower::pft::Evaluation &eval, const List<Clause> &clauses,
1677+
mlir::Location loc, mlir::omp::TargetOperands &clauseOps,
16781678
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
16791679
llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
16801680
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
16811681
ClauseProcessor cp(converter, semaCtx, clauses);
16821682
cp.processBare(clauseOps);
1683-
cp.processDepend(clauseOps);
1683+
cp.processDepend(symTable, stmtCtx, clauseOps);
16841684
cp.processDevice(stmtCtx, clauseOps);
16851685
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
16861686
if (!hostEvalInfo.empty()) {
@@ -1731,11 +1731,12 @@ static void genTargetDataClauses(
17311731

17321732
static void genTargetEnterExitUpdateDataClauses(
17331733
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1734-
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
1735-
mlir::Location loc, llvm::omp::Directive directive,
1734+
lower::SymMap &symTable, lower::StatementContext &stmtCtx,
1735+
const List<Clause> &clauses, mlir::Location loc,
1736+
llvm::omp::Directive directive,
17361737
mlir::omp::TargetEnterExitUpdateDataOperands &clauseOps) {
17371738
ClauseProcessor cp(converter, semaCtx, clauses);
1738-
cp.processDepend(clauseOps);
1739+
cp.processDepend(symTable, stmtCtx, clauseOps);
17391740
cp.processDevice(stmtCtx, clauseOps);
17401741
cp.processIf(directive, clauseOps);
17411742

@@ -1749,12 +1750,13 @@ static void genTargetEnterExitUpdateDataClauses(
17491750

17501751
static void genTaskClauses(lower::AbstractConverter &converter,
17511752
semantics::SemanticsContext &semaCtx,
1753+
lower::SymMap &symTable,
17521754
lower::StatementContext &stmtCtx,
17531755
const List<Clause> &clauses, mlir::Location loc,
17541756
mlir::omp::TaskOperands &clauseOps) {
17551757
ClauseProcessor cp(converter, semaCtx, clauses);
17561758
cp.processAllocate(clauseOps);
1757-
cp.processDepend(clauseOps);
1759+
cp.processDepend(symTable, stmtCtx, clauseOps);
17581760
cp.processFinal(stmtCtx, clauseOps);
17591761
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
17601762
cp.processMergeable(clauseOps);
@@ -2197,8 +2199,8 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
21972199
mlir::omp::TargetOperands clauseOps;
21982200
llvm::SmallVector<const semantics::Symbol *> mapSyms, isDevicePtrSyms,
21992201
hasDeviceAddrSyms;
2200-
genTargetClauses(converter, semaCtx, stmtCtx, eval, item->clauses, loc,
2201-
clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms);
2202+
genTargetClauses(converter, semaCtx, symTable, stmtCtx, eval, item->clauses,
2203+
loc, clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms);
22022204

22032205
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
22042206
/*shouldCollectPreDeterminedSymbols=*/
@@ -2418,7 +2420,7 @@ static OpTy genTargetEnterExitUpdateDataOp(
24182420
}
24192421

24202422
mlir::omp::TargetEnterExitUpdateDataOperands clauseOps;
2421-
genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx,
2423+
genTargetEnterExitUpdateDataClauses(converter, semaCtx, symTable, stmtCtx,
24222424
item->clauses, loc, directive, clauseOps);
24232425

24242426
return firOpBuilder.create<OpTy>(loc, clauseOps);
@@ -2431,7 +2433,8 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24312433
ConstructQueue::const_iterator item) {
24322434
lower::StatementContext stmtCtx;
24332435
mlir::omp::TaskOperands clauseOps;
2434-
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
2436+
genTaskClauses(converter, semaCtx, symTable, stmtCtx, item->clauses, loc,
2437+
clauseOps);
24352438

24362439
if (!enableDelayedPrivatization)
24372440
return genOpWithBody<mlir::omp::TaskOp>(
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
! RUN: %not_todo_cmd bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
2+
! RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
3+
4+
! CHECK: Vector subscripted array section for task dependency
5+
subroutine vectorSubscriptArraySection(array, indices)
6+
integer :: array(:)
7+
integer :: indices(:)
8+
9+
!$omp task depend (in: array(indices))
10+
!$omp end task
11+
end subroutine

flang/test/Lower/OpenMP/target.f90

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,24 @@ subroutine omp_target_enter_depend
3535
return
3636
end subroutine omp_target_enter_depend
3737

38+
!CHECK-LABEL: func.func @_QPomp_target_enter_depend_section() {
39+
subroutine omp_target_enter_depend_section
40+
!CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_enter_depend_sectionEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
41+
integer :: a(1024)
42+
43+
!CHECK: %[[DESIGNATE:.*]] = hlfir.designate %[[A]]#0 ({{.*}}) shape %{{.*}} : (!fir.ref<!fir.array<1024xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<512xi32>>
44+
!CHECK: omp.task depend(taskdependout -> %[[DESIGNATE]] : !fir.ref<!fir.array<512xi32>>) private({{.*}}) {
45+
!$omp task depend(out: a(1:512))
46+
call foo(a)
47+
!$omp end task
48+
!CHECK: %[[DESIGNATE2:.*]] = hlfir.designate %[[A]]#0 ({{.*}}) shape %{{.*}} : (!fir.ref<!fir.array<1024xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<512xi32>>
49+
!CHECK: %[[BOUNDS:.*]] = omp.map.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
50+
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
51+
!CHECK: omp.target_enter_data depend(taskdependin -> %[[DESIGNATE2]] : !fir.ref<!fir.array<512xi32>>) map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
52+
!$omp target enter data map(to: a) depend(in: a(1:512))
53+
return
54+
end subroutine omp_target_enter_depend_section
55+
3856
!===============================================================================
3957
! Target_Enter Map types
4058
!===============================================================================
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
2+
3+
subroutine knownShape(array)
4+
integer :: array(10)
5+
6+
!$omp task depend(in: array(2:8))
7+
!$omp end task
8+
end subroutine
9+
10+
! CHECK-LABEL: func.func @_QPknownshape(
11+
! CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "array"}) {
12+
! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
13+
! CHECK: %[[VAL_2:.*]] = arith.constant 10 : index
14+
! CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
15+
! CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) dummy_scope %[[VAL_1]] {uniq_name = "_QFknownshapeEarray"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
16+
! CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
17+
! CHECK: %[[VAL_6:.*]] = arith.constant 8 : index
18+
! CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
19+
! CHECK: %[[VAL_8:.*]] = arith.constant 7 : index
20+
! CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
21+
! CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_5]]:%[[VAL_6]]:%[[VAL_7]]) shape %[[VAL_9]] : (!fir.ref<!fir.array<10xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<7xi32>>
22+
! CHECK: omp.task depend(taskdependin -> %[[VAL_10]] : !fir.ref<!fir.array<7xi32>>) {
23+
! CHECK: omp.terminator
24+
! CHECK: }
25+
! CHECK: return
26+
! CHECK: }
27+
28+
29+
subroutine assumedShape(array)
30+
integer :: array(:)
31+
32+
!$omp task depend(in: array(2:8:2))
33+
!$omp end task
34+
end subroutine
35+
36+
! CHECK-LABEL: func.func @_QPassumedshape(
37+
! CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}) {
38+
! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
39+
! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]] {uniq_name = "_QFassumedshapeEarray"} : (!fir.box<!fir.array<?xi32>>, !fir.dscope) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
40+
! CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
41+
! CHECK: %[[VAL_4:.*]] = arith.constant 8 : index
42+
! CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
43+
! CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
44+
! CHECK: %[[VAL_7:.*]] = fir.shape %[[VAL_6]] : (index) -> !fir.shape<1>
45+
! CHECK: %[[VAL_8:.*]] = hlfir.designate %[[VAL_2]]#0 (%[[VAL_3]]:%[[VAL_4]]:%[[VAL_5]]) shape %[[VAL_7]] : (!fir.box<!fir.array<?xi32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<4xi32>>
46+
! CHECK: %[[VAL_9:.*]] = fir.box_addr %[[VAL_8]] : (!fir.box<!fir.array<4xi32>>) -> !fir.ref<!fir.array<4xi32>>
47+
! CHECK: omp.task depend(taskdependin -> %[[VAL_9]] : !fir.ref<!fir.array<4xi32>>) {
48+
! CHECK: omp.terminator
49+
! CHECK: }
50+
! CHECK: return
51+
! CHECK: }

flang/test/Lower/OpenMP/task.f90

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ subroutine task_depend_non_int()
9393
character(len = 15) :: x
9494
integer, allocatable :: y
9595
complex :: z
96-
!CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<!fir.char<1,15>>, taskdependin -> %{{.+}} : !fir.ref<!fir.box<!fir.heap<i32>>>, taskdependin -> %{{.+}} : !fir.ref<complex<f32>>) {
96+
!CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<!fir.char<1,15>>, taskdependin -> %{{.+}} : !fir.heap<i32>, taskdependin -> %{{.+}} : !fir.ref<complex<f32>>) {
9797
!$omp task depend(in : x, y, z)
9898
!CHECK: omp.terminator
9999
!$omp end task
@@ -158,6 +158,23 @@ subroutine task_depend_multi_task()
158158
!$omp end task
159159
end subroutine task_depend_multi_task
160160

161+
subroutine task_depend_box(array)
162+
integer :: array(:)
163+
!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %{{.*}} : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
164+
!CHECK: omp.task depend(taskdependin -> %[[BOX_ADDR]] : !fir.ref<!fir.array<?xi32>>)
165+
!$omp task depend(in: array)
166+
!$omp end task
167+
end subroutine
168+
169+
subroutine task_depend_mutable_box(alloc)
170+
integer, allocatable :: alloc
171+
!CHECK: %[[LOAD:.*]] = fir.load %{{.*}} : !fir.ref<!fir.box<!fir.heap<i32>>>
172+
!CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[LOAD]] : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
173+
!CHECK: omp.task depend(taskdependin -> %[[BOX_ADDR]] : !fir.heap<i32>)
174+
!$omp task depend(in: alloc)
175+
!$omp end task
176+
end subroutine
177+
161178
!===============================================================================
162179
! `private` clause
163180
!===============================================================================

0 commit comments

Comments
 (0)