Skip to content

Commit 9e7d529

Browse files
authored
[Flang][OpenMP] Support for lowering task_reduction and in_reduction to MLIR (llvm#111155)
This patch: - Adds support for lowering task_reduction to MLIR - Adds support for lowering in_reduction to MLIR - Fixes incorrect DSA (data-sharing attribute) handling for variables in the presence of an 'in_reduction' clause.
1 parent 22d10f0 commit 9e7d529

13 files changed

+299
-64
lines changed

flang/include/flang/Semantics/symbol.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -754,12 +754,12 @@ class Symbol {
754754
// OpenMP data-copying attribute
755755
OmpCopyIn, OmpCopyPrivate,
756756
// OpenMP miscellaneous flags
757-
OmpCommonBlock, OmpReduction, OmpAligned, OmpNontemporal, OmpAllocate,
758-
OmpDeclarativeAllocateDirective, OmpExecutableAllocateDirective,
759-
OmpDeclareSimd, OmpDeclareTarget, OmpThreadprivate, OmpDeclareReduction,
760-
OmpFlushed, OmpCriticalLock, OmpIfSpecified, OmpNone, OmpPreDetermined,
761-
OmpImplicit, OmpDependObject, OmpInclusiveScan, OmpExclusiveScan,
762-
OmpInScanReduction);
757+
OmpCommonBlock, OmpReduction, OmpInReduction, OmpAligned, OmpNontemporal,
758+
OmpAllocate, OmpDeclarativeAllocateDirective,
759+
OmpExecutableAllocateDirective, OmpDeclareSimd, OmpDeclareTarget,
760+
OmpThreadprivate, OmpDeclareReduction, OmpFlushed, OmpCriticalLock,
761+
OmpIfSpecified, OmpNone, OmpPreDetermined, OmpImplicit, OmpDependObject,
762+
OmpInclusiveScan, OmpExclusiveScan, OmpInScanReduction);
763763
using Flags = common::EnumSet<Flag, Flag_enumSize>;
764764

765765
const Scope &owner() const { return *owner_; }

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,29 @@ bool ClauseProcessor::processIf(
983983
});
984984
return found;
985985
}
986+
bool ClauseProcessor::processInReduction(
987+
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
988+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
989+
return findRepeatableClause<omp::clause::InReduction>(
990+
[&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
991+
llvm::SmallVector<mlir::Value> inReductionVars;
992+
llvm::SmallVector<bool> inReduceVarByRef;
993+
llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
994+
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
995+
ReductionProcessor rp;
996+
rp.processReductionArguments<omp::clause::InReduction>(
997+
currentLocation, converter, clause, inReductionVars,
998+
inReduceVarByRef, inReductionDeclSymbols, inReductionSyms);
999+
1000+
// Copy local lists into the output.
1001+
llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
1002+
llvm::copy(inReduceVarByRef,
1003+
std::back_inserter(result.inReductionByref));
1004+
llvm::copy(inReductionDeclSymbols,
1005+
std::back_inserter(result.inReductionSyms));
1006+
llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
1007+
});
1008+
}
9861009

9871010
bool ClauseProcessor::processIsDevicePtr(
9881011
mlir::omp::IsDevicePtrClauseOps &result,
@@ -1257,9 +1280,9 @@ bool ClauseProcessor::processReduction(
12571280
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
12581281
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
12591282
ReductionProcessor rp;
1260-
rp.processReductionArguments(
1283+
rp.processReductionArguments<omp::clause::Reduction>(
12611284
currentLocation, converter, clause, reductionVars, reduceVarByRef,
1262-
reductionDeclSymbols, reductionSyms, result.reductionMod);
1285+
reductionDeclSymbols, reductionSyms, &result.reductionMod);
12631286
// Copy local lists into the output.
12641287
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
12651288
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
@@ -1269,6 +1292,30 @@ bool ClauseProcessor::processReduction(
12691292
});
12701293
}
12711294

1295+
bool ClauseProcessor::processTaskReduction(
1296+
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
1297+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
1298+
return findRepeatableClause<omp::clause::TaskReduction>(
1299+
[&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
1300+
llvm::SmallVector<mlir::Value> taskReductionVars;
1301+
llvm::SmallVector<bool> TaskReduceVarByRef;
1302+
llvm::SmallVector<mlir::Attribute> TaskReductionDeclSymbols;
1303+
llvm::SmallVector<const semantics::Symbol *> TaskReductionSyms;
1304+
ReductionProcessor rp;
1305+
rp.processReductionArguments<omp::clause::TaskReduction>(
1306+
currentLocation, converter, clause, taskReductionVars,
1307+
TaskReduceVarByRef, TaskReductionDeclSymbols, TaskReductionSyms);
1308+
// Copy local lists into the output.
1309+
llvm::copy(taskReductionVars,
1310+
std::back_inserter(result.taskReductionVars));
1311+
llvm::copy(TaskReduceVarByRef,
1312+
std::back_inserter(result.taskReductionByref));
1313+
llvm::copy(TaskReductionDeclSymbols,
1314+
std::back_inserter(result.taskReductionSyms));
1315+
llvm::copy(TaskReductionSyms, std::back_inserter(outReductionSyms));
1316+
});
1317+
}
1318+
12721319
bool ClauseProcessor::processTo(
12731320
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
12741321
return findRepeatableClause<omp::clause::To>(

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ class ClauseProcessor {
112112
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
113113
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
114114
mlir::omp::IfClauseOps &result) const;
115+
bool processInReduction(
116+
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
117+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
115118
bool processIsDevicePtr(
116119
mlir::omp::IsDevicePtrClauseOps &result,
117120
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
@@ -133,6 +136,9 @@ class ClauseProcessor {
133136
bool processReduction(
134137
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
135138
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
139+
bool processTaskReduction(
140+
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
141+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
136142
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
137143
bool processUseDeviceAddr(
138144
lower::StatementContext &stmtCtx,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,34 +1774,34 @@ static void genTargetEnterExitUpdateDataClauses(
17741774
cp.processNowait(clauseOps);
17751775
}
17761776

1777-
static void genTaskClauses(lower::AbstractConverter &converter,
1778-
semantics::SemanticsContext &semaCtx,
1779-
lower::SymMap &symTable,
1780-
lower::StatementContext &stmtCtx,
1781-
const List<Clause> &clauses, mlir::Location loc,
1782-
mlir::omp::TaskOperands &clauseOps) {
1777+
static void genTaskClauses(
1778+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1779+
lower::SymMap &symTable, lower::StatementContext &stmtCtx,
1780+
const List<Clause> &clauses, mlir::Location loc,
1781+
mlir::omp::TaskOperands &clauseOps,
1782+
llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
17831783
ClauseProcessor cp(converter, semaCtx, clauses);
17841784
cp.processAllocate(clauseOps);
17851785
cp.processDepend(symTable, stmtCtx, clauseOps);
17861786
cp.processFinal(stmtCtx, clauseOps);
17871787
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
1788+
cp.processInReduction(loc, clauseOps, inReductionSyms);
17881789
cp.processMergeable(clauseOps);
17891790
cp.processPriority(stmtCtx, clauseOps);
17901791
cp.processUntied(clauseOps);
17911792
cp.processDetach(clauseOps);
17921793

1793-
cp.processTODO<clause::Affinity, clause::InReduction>(
1794-
loc, llvm::omp::Directive::OMPD_task);
1794+
cp.processTODO<clause::Affinity>(loc, llvm::omp::Directive::OMPD_task);
17951795
}
17961796

1797-
static void genTaskgroupClauses(lower::AbstractConverter &converter,
1798-
semantics::SemanticsContext &semaCtx,
1799-
const List<Clause> &clauses, mlir::Location loc,
1800-
mlir::omp::TaskgroupOperands &clauseOps) {
1797+
static void genTaskgroupClauses(
1798+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1799+
const List<Clause> &clauses, mlir::Location loc,
1800+
mlir::omp::TaskgroupOperands &clauseOps,
1801+
llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
18011802
ClauseProcessor cp(converter, semaCtx, clauses);
18021803
cp.processAllocate(clauseOps);
1803-
cp.processTODO<clause::TaskReduction>(loc,
1804-
llvm::omp::Directive::OMPD_taskgroup);
1804+
cp.processTaskReduction(loc, clauseOps, taskReductionSyms);
18051805
}
18061806

18071807
static void genTaskloopClauses(lower::AbstractConverter &converter,
@@ -2496,8 +2496,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24962496
mlir::Location loc, const ConstructQueue &queue,
24972497
ConstructQueue::const_iterator item) {
24982498
mlir::omp::TaskOperands clauseOps;
2499+
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
24992500
genTaskClauses(converter, semaCtx, symTable, stmtCtx, item->clauses, loc,
2500-
clauseOps);
2501+
clauseOps, inReductionSyms);
25012502

25022503
if (!enableDelayedPrivatization)
25032504
return genOpWithBody<mlir::omp::TaskOp>(
@@ -2514,6 +2515,8 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25142515
EntryBlockArgs taskArgs;
25152516
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
25162517
taskArgs.priv.vars = clauseOps.privateVars;
2518+
taskArgs.inReduction.syms = inReductionSyms;
2519+
taskArgs.inReduction.vars = clauseOps.inReductionVars;
25172520

25182521
return genOpWithBody<mlir::omp::TaskOp>(
25192522
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
@@ -2531,12 +2534,19 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25312534
const ConstructQueue &queue,
25322535
ConstructQueue::const_iterator item) {
25332536
mlir::omp::TaskgroupOperands clauseOps;
2534-
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
2537+
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
2538+
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
2539+
taskReductionSyms);
2540+
2541+
EntryBlockArgs taskgroupArgs;
2542+
taskgroupArgs.taskReduction.syms = taskReductionSyms;
2543+
taskgroupArgs.taskReduction.vars = clauseOps.taskReductionVars;
25352544

25362545
return genOpWithBody<mlir::omp::TaskgroupOp>(
25372546
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
25382547
llvm::omp::Directive::OMPD_taskgroup)
2539-
.setClauses(&item->clauses),
2548+
.setClauses(&item->clauses)
2549+
.setEntryBlockArgs(&taskgroupArgs),
25402550
queue, item, clauseOps);
25412551
}
25422552

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "flang/Parser/tools.h"
2626
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2727
#include "llvm/Support/CommandLine.h"
28+
#include <type_traits>
2829

2930
static llvm::cl::opt<bool> forceByrefReduction(
3031
"force-byref-reduction",
@@ -38,6 +39,37 @@ namespace Fortran {
3839
namespace lower {
3940
namespace omp {
4041

42+
// Explicit template instantiations
43+
template void
44+
ReductionProcessor::processReductionArguments<omp::clause::Reduction>(
45+
mlir::Location currentLocation, lower::AbstractConverter &converter,
46+
const omp::clause::Reduction &reduction,
47+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
48+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
49+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
50+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
51+
mlir::omp::ReductionModifierAttr *reductionMod);
52+
53+
template void
54+
ReductionProcessor::processReductionArguments<omp::clause::TaskReduction>(
55+
mlir::Location currentLocation, lower::AbstractConverter &converter,
56+
const omp::clause::TaskReduction &reduction,
57+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
58+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
59+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
60+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
61+
mlir::omp::ReductionModifierAttr *reductionMod);
62+
63+
template void
64+
ReductionProcessor::processReductionArguments<omp::clause::InReduction>(
65+
mlir::Location currentLocation, lower::AbstractConverter &converter,
66+
const omp::clause::InReduction &reduction,
67+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
68+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
69+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
70+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
71+
mlir::omp::ReductionModifierAttr *reductionMod);
72+
4173
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
4274
const omp::clause::ProcedureDesignator &pd) {
4375
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
@@ -538,28 +570,30 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
538570
return mlir::omp::ReductionModifier::defaultmod;
539571
}
540572

573+
template <class T>
541574
void ReductionProcessor::processReductionArguments(
542575
mlir::Location currentLocation, lower::AbstractConverter &converter,
543-
const omp::clause::Reduction &reduction,
544-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
576+
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
545577
llvm::SmallVectorImpl<bool> &reduceVarByRef,
546578
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
547579
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
548-
mlir::omp::ReductionModifierAttr &reductionMod) {
580+
mlir::omp::ReductionModifierAttr *reductionMod) {
549581
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
550582

551-
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
552-
if (mod.has_value()) {
553-
if (mod.value() == ReductionModifier::Task)
554-
TODO(currentLocation, "Reduction modifier `task` is not supported");
555-
else
556-
reductionMod = mlir::omp::ReductionModifierAttr::get(
557-
firOpBuilder.getContext(), translateReductionModifier(mod.value()));
583+
if constexpr (std::is_same_v<T, omp::clause::Reduction>) {
584+
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
585+
if (mod.has_value()) {
586+
if (mod.value() == ReductionModifier::Task)
587+
TODO(currentLocation, "Reduction modifier `task` is not supported");
588+
else
589+
*reductionMod = mlir::omp::ReductionModifierAttr::get(
590+
firOpBuilder.getContext(), translateReductionModifier(mod.value()));
591+
}
558592
}
559593

560594
mlir::omp::DeclareReductionOp decl;
561595
const auto &redOperatorList{
562-
std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
596+
std::get<typename T::ReductionIdentifiers>(reduction.t)};
563597
assert(redOperatorList.size() == 1 && "Expecting single operator");
564598
const auto &redOperator = redOperatorList.front();
565599
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,14 @@ class ReductionProcessor {
121121

122122
/// Creates a reduction declaration and associates it with an OpenMP block
123123
/// directive.
124+
template <class T>
124125
static void processReductionArguments(
125126
mlir::Location currentLocation, lower::AbstractConverter &converter,
126-
const omp::clause::Reduction &reduction,
127-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
127+
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
128128
llvm::SmallVectorImpl<bool> &reduceVarByRef,
129129
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
130130
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
131-
mlir::omp::ReductionModifierAttr &reductionMod);
131+
mlir::omp::ReductionModifierAttr *reductionMod = nullptr);
132132
};
133133

134134
template <typename FloatOp, typename IntegerOp>

flang/lib/Semantics/resolve-directives.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,12 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
530530
return false;
531531
}
532532

533+
bool Pre(const parser::OmpInReductionClause &x) {
534+
auto &objects{std::get<parser::OmpObjectList>(x.t)};
535+
ResolveOmpObjectList(objects, Symbol::Flag::OmpInReduction);
536+
return false;
537+
}
538+
533539
bool Pre(const parser::OmpClause::Reduction &x) {
534540
const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
535541
ResolveOmpObjectList(objList, Symbol::Flag::OmpReduction);

flang/test/Lower/OpenMP/Todo/task-inreduction.f90

Lines changed: 0 additions & 15 deletions
This file was deleted.

flang/test/Lower/OpenMP/Todo/taskgroup-task-reduction.f90

Lines changed: 0 additions & 10 deletions
This file was deleted.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
3+
4+
!CHECK-LABEL: omp.declare_reduction
5+
!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
6+
!CHECK: ^bb0(%{{.*}}: i32):
7+
!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
8+
!CHECK: omp.yield(%[[C0_1]] : i32)
9+
!CHECK: } combiner {
10+
!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
11+
!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
12+
!CHECK: omp.yield(%[[RES]] : i32)
13+
!CHECK: }
14+
15+
!CHECK-LABEL: func.func @_QPomp_task_in_reduction() {
16+
! [...]
17+
!CHECK: omp.task in_reduction(@[[RED_I32_NAME]] %[[VAL_1:.*]]#0 -> %[[ARG0]] : !fir.ref<i32>) {
18+
!CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[ARG0]]
19+
!CHECK-SAME: {uniq_name = "_QFomp_task_in_reductionEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
20+
!CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<i32>
21+
!CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
22+
!CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
23+
!CHECK: hlfir.assign %[[VAL_7]] to %[[VAL_4]]#0 : i32, !fir.ref<i32>
24+
!CHECK: omp.terminator
25+
!CHECK: }
26+
!CHECK: return
27+
!CHECK: }
28+
29+
subroutine omp_task_in_reduction()
30+
integer i
31+
i = 0
32+
!$omp task in_reduction(+:i)
33+
i = i + 1
34+
!$omp end task
35+
end subroutine omp_task_in_reduction

0 commit comments

Comments
 (0)