@@ -621,10 +621,12 @@ class ClauseProcessor {
621
621
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr ,
622
622
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
623
623
*mapSymbols = nullptr ) const ;
624
- bool processReduction (
625
- mlir::Location currentLocation,
626
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
627
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const ;
624
+ bool
625
+ processReduction (mlir::Location currentLocation,
626
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
627
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
628
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
629
+ *reductionSymbols = nullptr ) const ;
628
630
bool processSectionsReduction (mlir::Location currentLocation) const ;
629
631
bool processTo (llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const ;
630
632
bool
@@ -1079,12 +1081,14 @@ class ReductionProcessor {
1079
1081
1080
1082
// / Creates a reduction declaration and associates it with an OpenMP block
1081
1083
// / directive.
1082
- static void addReductionDecl (
1083
- mlir::Location currentLocation,
1084
- Fortran::lower::AbstractConverter &converter,
1085
- const Fortran::parser::OmpReductionClause &reduction,
1086
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1087
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
1084
+ static void
1085
+ addReductionDecl (mlir::Location currentLocation,
1086
+ Fortran::lower::AbstractConverter &converter,
1087
+ const Fortran::parser::OmpReductionClause &reduction,
1088
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1089
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1090
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
1091
+ *reductionSymbols = nullptr ) {
1088
1092
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
1089
1093
mlir::omp::ReductionDeclareOp decl;
1090
1094
const auto &redOperator{
@@ -1114,6 +1118,8 @@ class ReductionProcessor {
1114
1118
if (const auto *name{
1115
1119
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1116
1120
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
1121
+ if (reductionSymbols)
1122
+ reductionSymbols->push_back (symbol);
1117
1123
mlir::Value symVal = converter.getSymbolAddress (*symbol);
1118
1124
if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
1119
1125
symVal = declOp.getBase ();
@@ -1148,6 +1154,8 @@ class ReductionProcessor {
1148
1154
if (const auto *name{
1149
1155
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1150
1156
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
1157
+ if (reductionSymbols)
1158
+ reductionSymbols->push_back (symbol);
1151
1159
mlir::Value symVal = converter.getSymbolAddress (*symbol);
1152
1160
if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
1153
1161
symVal = declOp.getBase ();
@@ -1948,13 +1956,16 @@ bool ClauseProcessor::processMap(
1948
1956
bool ClauseProcessor::processReduction (
1949
1957
mlir::Location currentLocation,
1950
1958
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1951
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
1959
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1960
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
1961
+ const {
1952
1962
return findRepeatableClause<ClauseTy::Reduction>(
1953
1963
[&](const ClauseTy::Reduction *reductionClause,
1954
1964
const Fortran::parser::CharBlock &) {
1955
1965
ReductionProcessor rp;
1956
1966
rp.addReductionDecl (currentLocation, converter, reductionClause->v ,
1957
- reductionVars, reductionDeclSymbols);
1967
+ reductionVars, reductionDeclSymbols,
1968
+ reductionSymbols);
1958
1969
});
1959
1970
}
1960
1971
@@ -2304,6 +2315,14 @@ struct OpWithBodyGenInfo {
2304
2315
return *this ;
2305
2316
}
2306
2317
2318
+ OpWithBodyGenInfo &
2319
+ setReductions (llvm::SmallVector<const Fortran::semantics::Symbol *> *value1,
2320
+ llvm::SmallVector<mlir::Type> *value2) {
2321
+ reductionSymbols = value1;
2322
+ reductionTypes = value2;
2323
+ return *this ;
2324
+ }
2325
+
2307
2326
OpWithBodyGenInfo &setGenRegionEntryCb (GenOMPRegionEntryCBFn value) {
2308
2327
genRegionEntryCB = value;
2309
2328
return *this ;
@@ -2323,6 +2342,11 @@ struct OpWithBodyGenInfo {
2323
2342
const Fortran::parser::OmpClauseList *clauses = nullptr ;
2324
2343
// / [in] if provided, processes the construct's data-sharing attributes.
2325
2344
DataSharingProcessor *dsp = nullptr ;
2345
+ // / [in] if provided, list of reduction symbols
2346
+ llvm::SmallVector<const Fortran::semantics::Symbol *> *reductionSymbols =
2347
+ nullptr ;
2348
+ // / [in] if provided, list of reduction types
2349
+ llvm::SmallVector<mlir::Type> *reductionTypes = nullptr ;
2326
2350
// / [in] if provided, emits the op's region entry. Otherwise, an emtpy block
2327
2351
// / is created in the region.
2328
2352
GenOMPRegionEntryCBFn genRegionEntryCB = nullptr ;
@@ -2567,6 +2591,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
2567
2591
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
2568
2592
reductionVars;
2569
2593
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2594
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
2570
2595
2571
2596
ClauseProcessor cp (converter, clauseList);
2572
2597
cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2576,13 +2601,33 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
2576
2601
cp.processDefault ();
2577
2602
cp.processAllocate (allocatorOperands, allocateOperands);
2578
2603
if (!outerCombined)
2579
- cp.processReduction (currentLocation, reductionVars, reductionDeclSymbols);
2604
+ cp.processReduction (currentLocation, reductionVars, reductionDeclSymbols,
2605
+ &reductionSymbols);
2606
+
2607
+ llvm::SmallVector<mlir::Type> reductionTypes;
2608
+ reductionTypes.reserve (reductionVars.size ());
2609
+ llvm::transform (reductionVars, std::back_inserter (reductionTypes),
2610
+ [](mlir::Value v) { return v.getType (); });
2611
+
2612
+ auto reductionCallback = [&](mlir::Operation *op) {
2613
+ llvm::SmallVector<mlir::Location> locs (reductionVars.size (),
2614
+ currentLocation);
2615
+ auto block = converter.getFirOpBuilder ().createBlock (&op->getRegion (0 ), {},
2616
+ reductionTypes, locs);
2617
+ for (auto [arg, prv] :
2618
+ llvm::zip_equal (reductionSymbols, block->getArguments ())) {
2619
+ converter.bindSymbol (*arg, prv);
2620
+ }
2621
+ return reductionSymbols;
2622
+ };
2580
2623
2581
2624
return genOpWithBody<mlir::omp::ParallelOp>(
2582
2625
OpWithBodyGenInfo (converter, currentLocation, eval)
2583
2626
.setGenNested (genNested)
2584
2627
.setOuterCombined (outerCombined)
2585
- .setClauses (&clauseList),
2628
+ .setClauses (&clauseList)
2629
+ .setReductions (&reductionSymbols, &reductionTypes)
2630
+ .setGenRegionEntryCb (reductionCallback),
2586
2631
/* resultTypes=*/ mlir::TypeRange (), ifClauseOperand,
2587
2632
numThreadsClauseOperand, allocateOperands, allocatorOperands,
2588
2633
reductionVars,
@@ -3634,10 +3679,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
3634
3679
break ;
3635
3680
}
3636
3681
3637
- if (singleDirective) {
3638
- genOpenMPReduction (converter, beginClauseList);
3682
+ if (singleDirective)
3639
3683
return ;
3640
- }
3641
3684
3642
3685
// Codegen for combined directives
3643
3686
bool combinedDirective = false ;
@@ -3673,7 +3716,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
3673
3716
" )" );
3674
3717
3675
3718
genNestedEvaluations (converter, eval);
3676
- genOpenMPReduction (converter, beginClauseList);
3677
3719
}
3678
3720
3679
3721
static void
0 commit comments