@@ -621,10 +621,12 @@ class ClauseProcessor {
621
621
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr ,
622
622
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
623
623
*mapSymbols = nullptr ) const ;
624
- bool processReduction (
625
- mlir::Location currentLocation,
626
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
627
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const ;
624
+ bool
625
+ processReduction (mlir::Location currentLocation,
626
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
627
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
628
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
629
+ *reductionSymbols = nullptr ) const ;
628
630
bool processSectionsReduction (mlir::Location currentLocation) const ;
629
631
bool processTo (llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const ;
630
632
bool
@@ -1075,12 +1077,14 @@ class ReductionProcessor {
1075
1077
1076
1078
// / Creates a reduction declaration and associates it with an OpenMP block
1077
1079
// / directive.
1078
- static void addReductionDecl (
1079
- mlir::Location currentLocation,
1080
- Fortran::lower::AbstractConverter &converter,
1081
- const Fortran::parser::OmpReductionClause &reduction,
1082
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1083
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
1080
+ static void
1081
+ addReductionDecl (mlir::Location currentLocation,
1082
+ Fortran::lower::AbstractConverter &converter,
1083
+ const Fortran::parser::OmpReductionClause &reduction,
1084
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1085
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1086
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
1087
+ *reductionSymbols = nullptr ) {
1084
1088
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
1085
1089
mlir::omp::ReductionDeclareOp decl;
1086
1090
const auto &redOperator{
@@ -1110,6 +1114,8 @@ class ReductionProcessor {
1110
1114
if (const auto *name{
1111
1115
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1112
1116
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
1117
+ if (reductionSymbols)
1118
+ reductionSymbols->push_back (symbol);
1113
1119
mlir::Value symVal = converter.getSymbolAddress (*symbol);
1114
1120
if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
1115
1121
symVal = declOp.getBase ();
@@ -1142,6 +1148,8 @@ class ReductionProcessor {
1142
1148
if (const auto *name{
1143
1149
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1144
1150
if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
1151
+ if (reductionSymbols)
1152
+ reductionSymbols->push_back (symbol);
1145
1153
mlir::Value symVal = converter.getSymbolAddress (*symbol);
1146
1154
if (auto declOp = symVal.getDefiningOp <hlfir::DeclareOp>())
1147
1155
symVal = declOp.getBase ();
@@ -1935,13 +1943,16 @@ bool ClauseProcessor::processMap(
1935
1943
bool ClauseProcessor::processReduction (
1936
1944
mlir::Location currentLocation,
1937
1945
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
1938
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
1946
+ llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
1947
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
1948
+ const {
1939
1949
return findRepeatableClause<ClauseTy::Reduction>(
1940
1950
[&](const ClauseTy::Reduction *reductionClause,
1941
1951
const Fortran::parser::CharBlock &) {
1942
1952
ReductionProcessor rp;
1943
1953
rp.addReductionDecl (currentLocation, converter, reductionClause->v ,
1944
- reductionVars, reductionDeclSymbols);
1954
+ reductionVars, reductionDeclSymbols,
1955
+ reductionSymbols);
1945
1956
});
1946
1957
}
1947
1958
@@ -2250,8 +2261,11 @@ static void createBodyOfOp(
2250
2261
Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
2251
2262
Fortran::lower::pft::Evaluation &eval, bool genNested,
2252
2263
const Fortran::parser::OmpClauseList *clauses = nullptr ,
2253
- const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
2254
- bool outerCombined = false , DataSharingProcessor *dsp = nullptr ) {
2264
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs = {},
2265
+ bool outerCombined = false , DataSharingProcessor *dsp = nullptr ,
2266
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs =
2267
+ {},
2268
+ const llvm::SmallVector<mlir::Type> &reductionTypes = {}) {
2255
2269
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2256
2270
2257
2271
auto insertMarker = [](fir::FirOpBuilder &builder) {
@@ -2264,24 +2278,32 @@ static void createBodyOfOp(
2264
2278
// argument. Also update the symbol's address with the mlir argument value.
2265
2279
// e.g. For loops the argument is the induction variable. And all further
2266
2280
// uses of the induction variable should use this mlir value.
2267
- if (args .size ()) {
2281
+ if (loopArgs .size ()) {
2268
2282
std::size_t loopVarTypeSize = 0 ;
2269
- for (const Fortran::semantics::Symbol *arg : args )
2283
+ for (const Fortran::semantics::Symbol *arg : loopArgs )
2270
2284
loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
2271
2285
mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
2272
- llvm::SmallVector<mlir::Type> tiv (args .size (), loopVarType);
2273
- llvm::SmallVector<mlir::Location> locs (args .size (), loc);
2286
+ llvm::SmallVector<mlir::Type> tiv (loopArgs .size (), loopVarType);
2287
+ llvm::SmallVector<mlir::Location> locs (loopArgs .size (), loc);
2274
2288
firOpBuilder.createBlock (&op.getRegion (), {}, tiv, locs);
2275
2289
// The argument is not currently in memory, so make a temporary for the
2276
2290
// argument, and store it there, then bind that location to the argument.
2277
2291
mlir::Operation *storeOp = nullptr ;
2278
- for (auto [argIndex, argSymbol] : llvm::enumerate (args )) {
2292
+ for (auto [argIndex, argSymbol] : llvm::enumerate (loopArgs )) {
2279
2293
mlir::Value indexVal =
2280
2294
fir::getBase (op.getRegion ().front ().getArgument (argIndex));
2281
2295
storeOp =
2282
2296
createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
2283
2297
}
2284
2298
firOpBuilder.setInsertionPointAfter (storeOp);
2299
+ } else if (reductionArgs.size ()) {
2300
+ llvm::SmallVector<mlir::Location> locs (reductionArgs.size (), loc);
2301
+ auto block =
2302
+ firOpBuilder.createBlock (&op.getRegion (), {}, reductionTypes, locs);
2303
+ for (auto [arg, prv] :
2304
+ llvm::zip_equal (reductionArgs, block->getArguments ())) {
2305
+ converter.bindSymbol (*arg, prv);
2306
+ }
2285
2307
} else {
2286
2308
firOpBuilder.createBlock (&op.getRegion ());
2287
2309
}
@@ -2382,8 +2404,8 @@ static void createBodyOfOp(
2382
2404
assert (tempDsp.has_value ());
2383
2405
tempDsp->processStep2 (op, isLoop);
2384
2406
} else {
2385
- if (isLoop && args .size () > 0 )
2386
- dsp->setLoopIV (converter.getSymbolAddress (*args [0 ]));
2407
+ if (isLoop && loopArgs .size () > 0 )
2408
+ dsp->setLoopIV (converter.getSymbolAddress (*loopArgs [0 ]));
2387
2409
dsp->processStep2 (op, isLoop);
2388
2410
}
2389
2411
}
@@ -2468,7 +2490,8 @@ static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
2468
2490
currentLocation, std::forward<Args>(args)...);
2469
2491
createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
2470
2492
clauseList,
2471
- /* args=*/ {}, outerCombined);
2493
+ /* loopArgs=*/ {}, outerCombined, /* dsp=*/ nullptr ,
2494
+ /* reductionArgs=*/ {}, /* reductionTypes=*/ {});
2472
2495
return op;
2473
2496
}
2474
2497
@@ -2505,6 +2528,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
2505
2528
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
2506
2529
reductionVars;
2507
2530
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
2531
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
2508
2532
2509
2533
ClauseProcessor cp (converter, clauseList);
2510
2534
cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2514,18 +2538,29 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
2514
2538
cp.processDefault ();
2515
2539
cp.processAllocate (allocatorOperands, allocateOperands);
2516
2540
if (!outerCombined)
2517
- cp.processReduction (currentLocation, reductionVars, reductionDeclSymbols);
2541
+ cp.processReduction (currentLocation, reductionVars, reductionDeclSymbols,
2542
+ &reductionSymbols);
2518
2543
2519
- return genOpWithBody<mlir::omp::ParallelOp>(
2520
- converter, eval, genNested, currentLocation, outerCombined, &clauseList,
2521
- /* resultTypes=*/ mlir::TypeRange (), ifClauseOperand,
2544
+ auto op = converter.getFirOpBuilder ().create <mlir::omp::ParallelOp>(
2545
+ currentLocation, mlir::TypeRange (), ifClauseOperand,
2522
2546
numThreadsClauseOperand, allocateOperands, allocatorOperands,
2523
2547
reductionVars,
2524
2548
reductionDeclSymbols.empty ()
2525
2549
? nullptr
2526
2550
: mlir::ArrayAttr::get (converter.getFirOpBuilder ().getContext (),
2527
2551
reductionDeclSymbols),
2528
2552
procBindKindAttr);
2553
+
2554
+ llvm::SmallVector<mlir::Type> reductionTypes;
2555
+ reductionTypes.reserve (reductionVars.size ());
2556
+ llvm::transform (reductionVars, std::back_inserter (reductionTypes),
2557
+ [](mlir::Value v) { return v.getType (); });
2558
+ createBodyOfOp<mlir::omp::ParallelOp>(op, converter, currentLocation, eval,
2559
+ genNested, &clauseList, /* loopArgs=*/ {},
2560
+ outerCombined, /* dsp=*/ nullptr ,
2561
+ reductionSymbols, reductionTypes);
2562
+
2563
+ return op;
2529
2564
}
2530
2565
2531
2566
static mlir::omp::SectionOp
@@ -3517,10 +3552,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
3517
3552
break ;
3518
3553
}
3519
3554
3520
- if (singleDirective) {
3521
- genOpenMPReduction (converter, beginClauseList);
3555
+ if (singleDirective)
3522
3556
return ;
3523
- }
3524
3557
3525
3558
// Codegen for combined directives
3526
3559
bool combinedDirective = false ;
@@ -3556,7 +3589,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
3556
3589
" )" );
3557
3590
3558
3591
genNestedEvaluations (converter, eval);
3559
- genOpenMPReduction (converter, beginClauseList);
3560
3592
}
3561
3593
3562
3594
static void
0 commit comments