@@ -3352,6 +3352,57 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3352
3352
return args;
3353
3353
}
3354
3354
3355
+ static llvm::SmallVector<const Fortran::semantics::Symbol *>
3356
+ genLoopAndReductionVars (mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3357
+ mlir::Location &loc,
3358
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs,
3359
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs,
3360
+ llvm::SmallVector<mlir::Type> &reductionTypes) {
3361
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
3362
+
3363
+ llvm::SmallVector<mlir::Type> blockArgTypes;
3364
+ llvm::SmallVector<mlir::Location> blockArgLocs;
3365
+ blockArgTypes.reserve (loopArgs.size () + reductionArgs.size ());
3366
+ blockArgLocs.reserve (blockArgTypes.size ());
3367
+ mlir::Block *entryBlock;
3368
+
3369
+ if (loopArgs.size ()) {
3370
+ std::size_t loopVarTypeSize = 0 ;
3371
+ for (const Fortran::semantics::Symbol *arg : loopArgs)
3372
+ loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
3373
+ mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
3374
+ std::fill_n (std::back_inserter (blockArgTypes), loopArgs.size (),
3375
+ loopVarType);
3376
+ std::fill_n (std::back_inserter (blockArgLocs), loopArgs.size (), loc);
3377
+ }
3378
+ if (reductionArgs.size ()) {
3379
+ llvm::copy (reductionTypes, std::back_inserter (blockArgTypes));
3380
+ std::fill_n (std::back_inserter (blockArgLocs), reductionArgs.size (), loc);
3381
+ }
3382
+ entryBlock = firOpBuilder.createBlock (&op->getRegion (0 ), {}, blockArgTypes,
3383
+ blockArgLocs);
3384
+ // The argument is not currently in memory, so make a temporary for the
3385
+ // argument, and store it there, then bind that location to the argument.
3386
+ if (loopArgs.size ()) {
3387
+ mlir::Operation *storeOp = nullptr ;
3388
+ for (auto [argIndex, argSymbol] : llvm::enumerate (loopArgs)) {
3389
+ mlir::Value indexVal =
3390
+ fir::getBase (op->getRegion (0 ).front ().getArgument (argIndex));
3391
+ storeOp =
3392
+ createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
3393
+ }
3394
+ firOpBuilder.setInsertionPointAfter (storeOp);
3395
+ }
3396
+ // Bind the reduction arguments to their block arguments
3397
+ for (auto [arg, prv] : llvm::zip_equal (
3398
+ reductionArgs,
3399
+ llvm::drop_begin (entryBlock->getArguments (), loopArgs.size ()))) {
3400
+ converter.bindSymbol (*arg, prv);
3401
+ }
3402
+
3403
+ return loopArgs;
3404
+ }
3405
+
3355
3406
static void
3356
3407
createSimdLoop (Fortran::lower::AbstractConverter &converter,
3357
3408
Fortran::semantics::SemanticsContext &semaCtx,
@@ -3429,6 +3480,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
3429
3480
llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
3430
3481
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
3431
3482
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
3483
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
3432
3484
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
3433
3485
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
3434
3486
mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
@@ -3440,7 +3492,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
3440
3492
cp.processCollapse (loc, eval, lowerBound, upperBound, step, iv,
3441
3493
loopVarTypeSize);
3442
3494
cp.processScheduleChunk (stmtCtx, scheduleChunkClauseOperand);
3443
- cp.processReduction (loc, reductionVars, reductionDeclSymbols);
3495
+ cp.processReduction (loc, reductionVars, reductionDeclSymbols,
3496
+ &reductionSymbols);
3444
3497
cp.processTODO <Fortran::parser::OmpClause::Linear,
3445
3498
Fortran::parser::OmpClause::Order>(loc, ompDirective);
3446
3499
@@ -3484,14 +3537,20 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
3484
3537
auto *nestedEval = getCollapsedLoopEval (
3485
3538
eval, Fortran::lower::getCollapseValue (beginClauseList));
3486
3539
3540
+ llvm::SmallVector<mlir::Type> reductionTypes;
3541
+ reductionTypes.reserve (reductionVars.size ());
3542
+ llvm::transform (reductionVars, std::back_inserter (reductionTypes),
3543
+ [](mlir::Value v) { return v.getType (); });
3544
+
3487
3545
auto ivCallback = [&](mlir::Operation *op) {
3488
- return genLoopVars (op, converter, loc, iv);
3546
+ return genLoopAndReductionVars (op, converter, loc, iv, reductionSymbols, reductionTypes );
3489
3547
};
3490
3548
3491
3549
createBodyOfOp<mlir::omp::WsLoopOp>(
3492
3550
wsLoopOp, OpWithBodyGenInfo (converter, semaCtx, loc, *nestedEval)
3493
3551
.setClauses (&beginClauseList)
3494
3552
.setDataSharingProcessor (&dsp)
3553
+ .setReductions (&reductionSymbols, &reductionTypes)
3495
3554
.setGenRegionEntryCb (ivCallback));
3496
3555
}
3497
3556
@@ -3594,12 +3653,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
3594
3653
// 2.9.3.1 SIMD construct
3595
3654
createSimdLoop (converter, semaCtx, eval, ompDirective, loopOpClauseList,
3596
3655
currentLocation);
3656
+ genOpenMPReduction (converter, semaCtx, loopOpClauseList);
3597
3657
} else {
3598
3658
createWsLoop (converter, semaCtx, eval, ompDirective, loopOpClauseList,
3599
3659
endClauseList, currentLocation);
3600
3660
}
3601
-
3602
- genOpenMPReduction (converter, semaCtx, loopOpClauseList);
3603
3661
}
3604
3662
3605
3663
static void
0 commit comments