Skip to content

Commit d8e27cd

Browse files
committed
Changes for invoking scan Op
1 parent afcbcae commit d8e27cd

File tree

11 files changed

+127
-45
lines changed

11 files changed

+127
-45
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,19 @@ bool ClauseProcessor::processDistSchedule(
344344
return false;
345345
}
346346

347+
bool ClauseProcessor::processExclusive(
348+
mlir::Location currentLocation,
349+
mlir::omp::ExclusiveClauseOps &result) const {
350+
return findRepeatableClause<omp::clause::Exclusive>(
351+
[&](const omp::clause::Exclusive &clause, const parser::CharBlock &) {
352+
for (const Object &object : clause.v) {
353+
const semantics::Symbol *symbol = object.sym();
354+
mlir::Value symVal = converter.getSymbolAddress(*symbol);
355+
result.exclusiveVars.push_back(symVal);
356+
}
357+
});
358+
}
359+
347360
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
348361
mlir::omp::FilterClauseOps &result) const {
349362
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
@@ -380,6 +393,19 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
380393
return false;
381394
}
382395

396+
bool ClauseProcessor::processInclusive(
397+
mlir::Location currentLocation,
398+
mlir::omp::InclusiveClauseOps &result) const {
399+
return findRepeatableClause<omp::clause::Inclusive>(
400+
[&](const omp::clause::Inclusive &clause, const parser::CharBlock &) {
401+
for (const Object &object : clause.v) {
402+
const semantics::Symbol *symbol = object.sym();
403+
mlir::Value symVal = converter.getSymbolAddress(*symbol);
404+
result.inclusiveVars.push_back(symVal);
405+
}
406+
});
407+
}
408+
383409
bool ClauseProcessor::processMergeable(
384410
mlir::omp::MergeableClauseOps &result) const {
385411
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
@@ -1135,10 +1161,9 @@ bool ClauseProcessor::processReduction(
11351161
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
11361162
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
11371163
ReductionProcessor rp;
1138-
rp.addDeclareReduction(currentLocation, converter, clause,
1139-
reductionVars, reduceVarByRef,
1140-
reductionDeclSymbols, reductionSyms);
1141-
1164+
rp.addDeclareReduction(
1165+
currentLocation, converter, clause, reductionVars, reduceVarByRef,
1166+
reductionDeclSymbols, reductionSyms, &result.reductionMod);
11421167
// Copy local lists into the output.
11431168
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
11441169
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class ClauseProcessor {
6464
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
6565
bool processDistSchedule(lower::StatementContext &stmtCtx,
6666
mlir::omp::DistScheduleClauseOps &result) const;
67+
bool processExclusive(mlir::Location currentLocation,
68+
mlir::omp::ExclusiveClauseOps &result) const;
6769
bool processFilter(lower::StatementContext &stmtCtx,
6870
mlir::omp::FilterClauseOps &result) const;
6971
bool processFinal(lower::StatementContext &stmtCtx,
@@ -72,6 +74,8 @@ class ClauseProcessor {
7274
mlir::omp::HasDeviceAddrClauseOps &result,
7375
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
7476
bool processHint(mlir::omp::HintClauseOps &result) const;
77+
bool processInclusive(mlir::Location currentLocation,
78+
mlir::omp::InclusiveClauseOps &result) const;
7579
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
7680
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
7781
bool processNumTeams(lower::StatementContext &stmtCtx,

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp,
728728

729729
Exclusive make(const parser::OmpClause::Exclusive &inp,
730730
semantics::SemanticsContext &semaCtx) {
731-
// inp -> empty
732-
llvm_unreachable("Empty: exclusive");
731+
// inp.v -> parser::OmpObjectList
732+
return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
733733
}
734734

735735
Fail make(const parser::OmpClause::Fail &inp,
@@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp,
838838

839839
Inclusive make(const parser::OmpClause::Inclusive &inp,
840840
semantics::SemanticsContext &semaCtx) {
841-
// inp -> empty
842-
llvm_unreachable("Empty: inclusive");
841+
// inp.v -> parser::OmpObjectList
842+
return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
843843
}
844844

845845
Indirect make(const parser::OmpClause::Indirect &inp,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1578,6 +1578,15 @@ static void genParallelClauses(
15781578
cp.processReduction(loc, clauseOps, reductionSyms);
15791579
}
15801580

1581+
static void genScanClauses(lower::AbstractConverter &converter,
1582+
semantics::SemanticsContext &semaCtx,
1583+
const List<Clause> &clauses, mlir::Location loc,
1584+
mlir::omp::ScanOperands &clauseOps) {
1585+
ClauseProcessor cp(converter, semaCtx, clauses);
1586+
cp.processInclusive(loc, clauseOps);
1587+
cp.processExclusive(loc, clauseOps);
1588+
}
1589+
15811590
static void genSectionsClauses(
15821591
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
15831592
const List<Clause> &clauses, mlir::Location loc,
@@ -1975,6 +1984,16 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19751984
return parallelOp;
19761985
}
19771986

1987+
static mlir::omp::ScanOp
1988+
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1989+
semantics::SemanticsContext &semaCtx, mlir::Location loc,
1990+
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
1991+
mlir::omp::ScanOperands clauseOps;
1992+
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
1993+
return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
1994+
converter.getCurrentLocation(), clauseOps);
1995+
}
1996+
19781997
/// This breaks the normal prototype of the gen*Op functions: adding the
19791998
/// sectionBlocks argument so that the enclosed section constructs can be
19801999
/// lowered here with correct reduction symbol remapping.
@@ -2978,7 +2997,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
29782997
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
29792998
break;
29802999
case llvm::omp::Directive::OMPD_scan:
2981-
TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
3000+
genScanOp(converter, symTable, semaCtx, loc, queue, item);
29823001
break;
29833002
case llvm::omp::Directive::OMPD_section:
29843003
llvm_unreachable("genOMPDispatch: OMPD_section");

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "flang/Parser/tools.h"
2626
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2727
#include "llvm/Support/CommandLine.h"
28+
#include <string>
2829

2930
static llvm::cl::opt<bool> forceByrefReduction(
3031
"force-byref-reduction",
@@ -514,18 +515,36 @@ static bool doReductionByRef(mlir::Value reductionVar) {
514515
return false;
515516
}
516517

518+
mlir::omp::ReductionModifier
519+
translateReductionModifier(const ReductionModifier &m) {
520+
switch (m) {
521+
case ReductionModifier::Default:
522+
return mlir::omp::ReductionModifier::defaultmod;
523+
case ReductionModifier::Inscan:
524+
return mlir::omp::ReductionModifier::inscan;
525+
case ReductionModifier::Task:
526+
return mlir::omp::ReductionModifier::task;
527+
}
528+
return mlir::omp::ReductionModifier::defaultmod;
529+
}
530+
517531
void ReductionProcessor::addDeclareReduction(
518532
mlir::Location currentLocation, lower::AbstractConverter &converter,
519533
const omp::clause::Reduction &reduction,
520534
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
521535
llvm::SmallVectorImpl<bool> &reduceVarByRef,
522536
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
523-
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
537+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
538+
mlir::omp::ReductionModifierAttr *reductionMod) {
524539
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
525540

526-
if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
527-
reduction.t))
528-
TODO(currentLocation, "Reduction modifiers are not supported");
541+
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
542+
if (mod.has_value() && (mod.value() != ReductionModifier::Inscan)) {
543+
std::string modStr = "default";
544+
if (mod.value() == ReductionModifier::Task)
545+
modStr = "task";
546+
TODO(currentLocation, "Reduction modifier " + modStr + " is not supported");
547+
}
529548

530549
mlir::omp::DeclareReductionOp decl;
531550
const auto &redOperatorList{
@@ -649,6 +668,11 @@ void ReductionProcessor::addDeclareReduction(
649668
currentLocation, isByRef);
650669
reductionDeclSymbols.push_back(
651670
mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
671+
auto redMod = std::get<std::optional<ReductionModifier>>(reduction.t);
672+
if (redMod.has_value())
673+
*reductionMod = mlir::omp::ReductionModifierAttr::get(
674+
firOpBuilder.getContext(),
675+
translateReductionModifier(redMod.value()));
652676
}
653677
}
654678

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "flang/Parser/parse-tree.h"
2020
#include "flang/Semantics/symbol.h"
2121
#include "flang/Semantics/type.h"
22+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2223
#include "mlir/IR/Location.h"
2324
#include "mlir/IR/Types.h"
2425

@@ -126,7 +127,8 @@ class ReductionProcessor {
126127
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
127128
llvm::SmallVectorImpl<bool> &reduceVarByRef,
128129
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
129-
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
130+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
131+
mlir::omp::ReductionModifierAttr *reductionMod);
130132
};
131133

132134
template <typename FloatOp, typename IntegerOp>
@@ -156,6 +158,8 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
156158
return builder.create<ComplexOp>(loc, op1, op2);
157159
}
158160

161+
using ReductionModifier = omp::clause::Reduction::ReductionModifier;
162+
159163
} // namespace omp
160164
} // namespace lower
161165
} // namespace Fortran

flang/test/Lower/OpenMP/Todo/reduction-inscan.f90

Lines changed: 0 additions & 15 deletions
This file was deleted.

flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90

Lines changed: 0 additions & 14 deletions
This file was deleted.

flang/test/Lower/OpenMP/Todo/reduction-task.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
22
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
33

4-
! CHECK: not yet implemented: Reduction modifiers are not supported
4+
! CHECK: not yet implemented: Reduction modifier task is not supported
55
subroutine reduction_task()
66
integer :: i
77
i = 0

flang/test/Lower/OpenMP/scan.f90

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
2+
3+
subroutine inclusive_scan
4+
implicit none
5+
integer, parameter :: n = 100
6+
integer a(n), b(n)
7+
integer x, k
8+
9+
!CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
10+
!$omp parallel do reduction(inscan, +: x)
11+
do k = 1, n
12+
x = x + a(k)
13+
!CHECK: omp.scan inclusive({{.*}})
14+
!$omp scan inclusive(x)
15+
b(k) = x
16+
end do
17+
end subroutine inclusive_scan
18+
19+
20+
subroutine exclusive_scan
21+
implicit none
22+
integer, parameter :: n = 100
23+
integer a(n), b(n)
24+
integer x, k
25+
26+
!CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
27+
!$omp parallel do reduction(inscan, +: x)
28+
do k = 1, n
29+
x = x + a(k)
30+
!CHECK: omp.scan exclusive({{.*}})
31+
!$omp scan exclusive(x)
32+
b(k) = x
33+
end do
34+
end subroutine exclusive_scan

mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
226226
target.addDynamicallyLegalOp<
227227
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
228228
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
229-
omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
229+
omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp,
230230
omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
231231
omp::YieldOp>([&](Operation *op) {
232232
return typeConverter.isLegal(op->getOperandTypes()) &&
@@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
264264
RegionLessOpConversion<omp::CancelOp>,
265265
RegionLessOpConversion<omp::CriticalDeclareOp>,
266266
RegionLessOpConversion<omp::OrderedOp>,
267+
RegionLessOpConversion<omp::ScanOp>,
267268
RegionLessOpConversion<omp::TargetEnterDataOp>,
268269
RegionLessOpConversion<omp::TargetExitDataOp>,
269270
RegionLessOpConversion<omp::TargetUpdateOp>,

0 commit comments

Comments
 (0)