@@ -2366,12 +2366,6 @@ static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) {
2366
2366
return undef.getDefiningOp ();
2367
2367
};
2368
2368
2369
- llvm::SmallVector<mlir::Type> blockArgTypes;
2370
- llvm::SmallVector<mlir::Location> blockArgLocs;
2371
- blockArgTypes.reserve (loopArgs.size () + reductionArgs.size ());
2372
- blockArgLocs.reserve (blockArgTypes.size ());
2373
- mlir::Block *entryBlock;
2374
-
2375
2369
// If an argument for the region is provided then create the block with that
2376
2370
// argument. Also update the symbol's address with the mlir argument value.
2377
2371
// e.g. For loops the argument is the induction variable. And all further
@@ -3358,6 +3352,57 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3358
3352
return args;
3359
3353
}
3360
3354
3355
+ static llvm::SmallVector<const Fortran::semantics::Symbol *>
3356
+ genLoopAndReductionVars (mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3357
+ mlir::Location &loc,
3358
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs,
3359
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs,
3360
+ llvm::SmallVector<mlir::Type> &reductionTypes) {
3361
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
3362
+
3363
+ llvm::SmallVector<mlir::Type> blockArgTypes;
3364
+ llvm::SmallVector<mlir::Location> blockArgLocs;
3365
+ blockArgTypes.reserve (loopArgs.size () + reductionArgs.size ());
3366
+ blockArgLocs.reserve (blockArgTypes.size ());
3367
+ mlir::Block *entryBlock;
3368
+
3369
+ if (loopArgs.size ()) {
3370
+ std::size_t loopVarTypeSize = 0 ;
3371
+ for (const Fortran::semantics::Symbol *arg : loopArgs)
3372
+ loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
3373
+ mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
3374
+ std::fill_n (std::back_inserter (blockArgTypes), loopArgs.size (),
3375
+ loopVarType);
3376
+ std::fill_n (std::back_inserter (blockArgLocs), loopArgs.size (), loc);
3377
+ }
3378
+ if (reductionArgs.size ()) {
3379
+ llvm::copy (reductionTypes, std::back_inserter (blockArgTypes));
3380
+ std::fill_n (std::back_inserter (blockArgLocs), reductionArgs.size (), loc);
3381
+ }
3382
+ entryBlock = firOpBuilder.createBlock (&op->getRegion (0 ), {}, blockArgTypes,
3383
+ blockArgLocs);
3384
+ // The argument is not currently in memory, so make a temporary for the
3385
+ // argument, and store it there, then bind that location to the argument.
3386
+ if (loopArgs.size ()) {
3387
+ mlir::Operation *storeOp = nullptr ;
3388
+ for (auto [argIndex, argSymbol] : llvm::enumerate (loopArgs)) {
3389
+ mlir::Value indexVal =
3390
+ fir::getBase (op->getRegion (0 ).front ().getArgument (argIndex));
3391
+ storeOp =
3392
+ createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
3393
+ }
3394
+ firOpBuilder.setInsertionPointAfter (storeOp);
3395
+ }
3396
+ // Bind the reduction arguments to their block arguments
3397
+ for (auto [arg, prv] : llvm::zip_equal (
3398
+ reductionArgs,
3399
+ llvm::drop_begin (entryBlock->getArguments (), loopArgs.size ()))) {
3400
+ converter.bindSymbol (*arg, prv);
3401
+ }
3402
+
3403
+ return loopArgs;
3404
+ }
3405
+
3361
3406
static void
3362
3407
createSimdLoop (Fortran::lower::AbstractConverter &converter,
3363
3408
Fortran::semantics::SemanticsContext &semaCtx,
@@ -3492,19 +3537,20 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
3492
3537
auto *nestedEval = getCollapsedLoopEval (
3493
3538
eval, Fortran::lower::getCollapseValue (beginClauseList));
3494
3539
3540
+ llvm::SmallVector<mlir::Type> reductionTypes;
3541
+ reductionTypes.reserve (reductionVars.size ());
3542
+ llvm::transform (reductionVars, std::back_inserter (reductionTypes),
3543
+ [](mlir::Value v) { return v.getType (); });
3544
+
3495
3545
auto ivCallback = [&](mlir::Operation *op) {
3496
- return genLoopVars (op, converter, loc, iv);
3546
+ return genLoopAndReductionVars (op, converter, loc, iv, reductionSymbols, reductionTypes );
3497
3547
};
3498
3548
3499
- // llvm::SmallVector<mlir::Type> reductionTypes;
3500
- // reductionTypes.reserve(reductionVars.size());
3501
- // llvm::transform(reductionVars, std::back_inserter(reductionTypes),
3502
- // [](mlir::Value v) { return v.getType(); });
3503
-
3504
3549
createBodyOfOp<mlir::omp::WsLoopOp>(
3505
3550
wsLoopOp, OpWithBodyGenInfo (converter, semaCtx, loc, *nestedEval)
3506
3551
.setClauses (&beginClauseList)
3507
3552
.setDataSharingProcessor (&dsp)
3553
+ .setReductions (&reductionSymbols, &reductionTypes)
3508
3554
.setGenRegionEntryCb (ivCallback));
3509
3555
}
3510
3556
0 commit comments