Skip to content

Commit ccd92ec

Browse files
authored
[flang][openmp] Changes for invoking scan Op (#123254)
1 parent 290a0d8 commit ccd92ec

File tree

11 files changed

+128
-47
lines changed

11 files changed

+128
-47
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,20 @@ bool ClauseProcessor::processDistSchedule(
344344
return false;
345345
}
346346

347+
bool ClauseProcessor::processExclusive(
348+
mlir::Location currentLocation,
349+
mlir::omp::ExclusiveClauseOps &result) const {
350+
if (auto *clause = findUniqueClause<omp::clause::Exclusive>()) {
351+
for (const Object &object : clause->v) {
352+
const semantics::Symbol *symbol = object.sym();
353+
mlir::Value symVal = converter.getSymbolAddress(*symbol);
354+
result.exclusiveVars.push_back(symVal);
355+
}
356+
return true;
357+
}
358+
return false;
359+
}
360+
347361
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
348362
mlir::omp::FilterClauseOps &result) const {
349363
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
@@ -380,6 +394,20 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
380394
return false;
381395
}
382396

397+
bool ClauseProcessor::processInclusive(
398+
mlir::Location currentLocation,
399+
mlir::omp::InclusiveClauseOps &result) const {
400+
if (auto *clause = findUniqueClause<omp::clause::Inclusive>()) {
401+
for (const Object &object : clause->v) {
402+
const semantics::Symbol *symbol = object.sym();
403+
mlir::Value symVal = converter.getSymbolAddress(*symbol);
404+
result.inclusiveVars.push_back(symVal);
405+
}
406+
return true;
407+
}
408+
return false;
409+
}
410+
383411
bool ClauseProcessor::processMergeable(
384412
mlir::omp::MergeableClauseOps &result) const {
385413
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
@@ -1135,10 +1163,9 @@ bool ClauseProcessor::processReduction(
11351163
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
11361164
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
11371165
ReductionProcessor rp;
1138-
rp.addDeclareReduction(currentLocation, converter, clause,
1139-
reductionVars, reduceVarByRef,
1140-
reductionDeclSymbols, reductionSyms);
1141-
1166+
rp.processReductionArguments(
1167+
currentLocation, converter, clause, reductionVars, reduceVarByRef,
1168+
reductionDeclSymbols, reductionSyms, result.reductionMod);
11421169
// Copy local lists into the output.
11431170
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
11441171
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class ClauseProcessor {
6464
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
6565
bool processDistSchedule(lower::StatementContext &stmtCtx,
6666
mlir::omp::DistScheduleClauseOps &result) const;
67+
bool processExclusive(mlir::Location currentLocation,
68+
mlir::omp::ExclusiveClauseOps &result) const;
6769
bool processFilter(lower::StatementContext &stmtCtx,
6870
mlir::omp::FilterClauseOps &result) const;
6971
bool processFinal(lower::StatementContext &stmtCtx,
@@ -72,6 +74,8 @@ class ClauseProcessor {
7274
mlir::omp::HasDeviceAddrClauseOps &result,
7375
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
7476
bool processHint(mlir::omp::HintClauseOps &result) const;
77+
bool processInclusive(mlir::Location currentLocation,
78+
mlir::omp::InclusiveClauseOps &result) const;
7579
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
7680
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
7781
bool processNumTeams(lower::StatementContext &stmtCtx,

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,8 @@ Enter make(const parser::OmpClause::Enter &inp,
736736

737737
Exclusive make(const parser::OmpClause::Exclusive &inp,
738738
semantics::SemanticsContext &semaCtx) {
739-
// inp -> empty
740-
llvm_unreachable("Empty: exclusive");
739+
// inp.v -> parser::OmpObjectList
740+
return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
741741
}
742742

743743
Fail make(const parser::OmpClause::Fail &inp,
@@ -846,8 +846,8 @@ If make(const parser::OmpClause::If &inp,
846846

847847
Inclusive make(const parser::OmpClause::Inclusive &inp,
848848
semantics::SemanticsContext &semaCtx) {
849-
// inp -> empty
850-
llvm_unreachable("Empty: inclusive");
849+
// inp.v -> parser::OmpObjectList
850+
return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
851851
}
852852

853853
Indirect make(const parser::OmpClause::Indirect &inp,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1584,6 +1584,15 @@ static void genParallelClauses(
15841584
cp.processReduction(loc, clauseOps, reductionSyms);
15851585
}
15861586

1587+
static void genScanClauses(lower::AbstractConverter &converter,
1588+
semantics::SemanticsContext &semaCtx,
1589+
const List<Clause> &clauses, mlir::Location loc,
1590+
mlir::omp::ScanOperands &clauseOps) {
1591+
ClauseProcessor cp(converter, semaCtx, clauses);
1592+
cp.processInclusive(loc, clauseOps);
1593+
cp.processExclusive(loc, clauseOps);
1594+
}
1595+
15871596
static void genSectionsClauses(
15881597
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
15891598
const List<Clause> &clauses, mlir::Location loc,
@@ -1981,6 +1990,16 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19811990
return parallelOp;
19821991
}
19831992

1993+
static mlir::omp::ScanOp
1994+
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1995+
semantics::SemanticsContext &semaCtx, mlir::Location loc,
1996+
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
1997+
mlir::omp::ScanOperands clauseOps;
1998+
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
1999+
return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
2000+
converter.getCurrentLocation(), clauseOps);
2001+
}
2002+
19842003
/// This breaks the normal prototype of the gen*Op functions: adding the
19852004
/// sectionBlocks argument so that the enclosed section constructs can be
19862005
/// lowered here with correct reduction symbol remapping.
@@ -2990,7 +3009,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
29903009
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
29913010
break;
29923011
case llvm::omp::Directive::OMPD_scan:
2993-
TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
3012+
genScanOp(converter, symTable, semaCtx, loc, queue, item);
29943013
break;
29953014
case llvm::omp::Directive::OMPD_section:
29963015
llvm_unreachable("genOMPDispatch: OMPD_section");

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ static llvm::cl::opt<bool> forceByrefReduction(
3131
llvm::cl::desc("Pass all reduction arguments by reference"),
3232
llvm::cl::Hidden);
3333

34+
using ReductionModifier =
35+
Fortran::lower::omp::clause::Reduction::ReductionModifier;
36+
3437
namespace Fortran {
3538
namespace lower {
3639
namespace omp {
@@ -518,18 +521,36 @@ static bool doReductionByRef(mlir::Value reductionVar) {
518521
return false;
519522
}
520523

521-
void ReductionProcessor::addDeclareReduction(
524+
mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
525+
switch (mod) {
526+
case ReductionModifier::Default:
527+
return mlir::omp::ReductionModifier::defaultmod;
528+
case ReductionModifier::Inscan:
529+
return mlir::omp::ReductionModifier::inscan;
530+
case ReductionModifier::Task:
531+
return mlir::omp::ReductionModifier::task;
532+
}
533+
return mlir::omp::ReductionModifier::defaultmod;
534+
}
535+
536+
void ReductionProcessor::processReductionArguments(
522537
mlir::Location currentLocation, lower::AbstractConverter &converter,
523538
const omp::clause::Reduction &reduction,
524539
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
525540
llvm::SmallVectorImpl<bool> &reduceVarByRef,
526541
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
527-
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
542+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
543+
mlir::omp::ReductionModifierAttr &reductionMod) {
528544
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
529545

530-
if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
531-
reduction.t))
532-
TODO(currentLocation, "Reduction modifiers are not supported");
546+
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
547+
if (mod.has_value()) {
548+
if (mod.value() == ReductionModifier::Task)
549+
TODO(currentLocation, "Reduction modifier `task` is not supported");
550+
else
551+
reductionMod = mlir::omp::ReductionModifierAttr::get(
552+
firOpBuilder.getContext(), translateReductionModifier(mod.value()));
553+
}
533554

534555
mlir::omp::DeclareReductionOp decl;
535556
const auto &redOperatorList{

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "flang/Parser/parse-tree.h"
2020
#include "flang/Semantics/symbol.h"
2121
#include "flang/Semantics/type.h"
22+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2223
#include "mlir/IR/Location.h"
2324
#include "mlir/IR/Types.h"
2425

@@ -120,13 +121,14 @@ class ReductionProcessor {
120121

121122
/// Creates a reduction declaration and associates it with an OpenMP block
122123
/// directive.
123-
static void addDeclareReduction(
124+
static void processReductionArguments(
124125
mlir::Location currentLocation, lower::AbstractConverter &converter,
125126
const omp::clause::Reduction &reduction,
126127
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
127128
llvm::SmallVectorImpl<bool> &reduceVarByRef,
128129
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
129-
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
130+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
131+
mlir::omp::ReductionModifierAttr &reductionMod);
130132
};
131133

132134
template <typename FloatOp, typename IntegerOp>

flang/test/Lower/OpenMP/Todo/reduction-inscan.f90

Lines changed: 0 additions & 15 deletions
This file was deleted.

flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90

Lines changed: 0 additions & 14 deletions
This file was deleted.

flang/test/Lower/OpenMP/Todo/reduction-task.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
22
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
33

4-
! CHECK: not yet implemented: Reduction modifiers are not supported
4+
! CHECK: not yet implemented: Reduction modifier `task` is not supported
55
subroutine reduction_task()
66
integer :: i
77
i = 0

flang/test/Lower/OpenMP/scan.f90

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
! RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
3+
4+
! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_1:.*]] : {{.*}}) {
5+
! CHECK: %[[RED_DECL_1:.*]]:2 = hlfir.declare %[[RED_ARG_1]]
6+
! CHECK: omp.scan inclusive(%[[RED_DECL_1]]#1 : {{.*}})
7+
8+
subroutine inclusive_scan(a, b, n)
9+
implicit none
10+
integer a(:), b(:)
11+
integer x, k, n
12+
13+
!$omp parallel do reduction(inscan, +: x)
14+
do k = 1, n
15+
x = x + a(k)
16+
!$omp scan inclusive(x)
17+
b(k) = x
18+
end do
19+
end subroutine inclusive_scan
20+
21+
22+
! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_2:.*]] : {{.*}}) {
23+
! CHECK: %[[RED_DECL_2:.*]]:2 = hlfir.declare %[[RED_ARG_2]]
24+
! CHECK: omp.scan exclusive(%[[RED_DECL_2]]#1 : {{.*}})
25+
subroutine exclusive_scan(a, b, n)
26+
implicit none
27+
integer a(:), b(:)
28+
integer x, k, n
29+
30+
!$omp parallel do reduction(inscan, +: x)
31+
do k = 1, n
32+
x = x + a(k)
33+
!$omp scan exclusive(x)
34+
b(k) = x
35+
end do
36+
end subroutine exclusive_scan

mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
226226
target.addDynamicallyLegalOp<
227227
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
228228
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
229-
omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
229+
omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp,
230230
omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
231231
omp::YieldOp>([&](Operation *op) {
232232
return typeConverter.isLegal(op->getOperandTypes()) &&
@@ -274,6 +274,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
274274
RegionLessOpConversion<omp::CancelOp>,
275275
RegionLessOpConversion<omp::CriticalDeclareOp>,
276276
RegionLessOpConversion<omp::OrderedOp>,
277+
RegionLessOpConversion<omp::ScanOp>,
277278
RegionLessOpConversion<omp::TargetEnterDataOp>,
278279
RegionLessOpConversion<omp::TargetExitDataOp>,
279280
RegionLessOpConversion<omp::TargetUpdateOp>,

0 commit comments

Comments
 (0)