@@ -2274,6 +2274,12 @@ static void createBodyOfOp(
2274
2274
return undef.getDefiningOp ();
2275
2275
};
2276
2276
2277
+ llvm::SmallVector<mlir::Type> blockArgTypes;
2278
+ llvm::SmallVector<mlir::Location> blockArgLocs;
2279
+ // Reserve room for every entry-block argument (loop variables, then reductions).
+ blockArgTypes.reserve (loopArgs.size () + reductionArgs.size ());
2280
+ // reserve () leaves size () at 0, so blockArgTypes.size () cannot be used here.
+ blockArgLocs.reserve (loopArgs.size () + reductionArgs.size ());
2281
+ mlir::Block *entryBlock;
2282
+
2277
2283
// If an argument for the region is provided then create the block with that
2278
2284
// argument. Also update the symbol's address with the mlir argument value.
2279
2285
// e.g. For loops the argument is the induction variable. And all further
@@ -2283,11 +2289,21 @@ static void createBodyOfOp(
2283
2289
for (const Fortran::semantics::Symbol *arg : loopArgs)
2284
2290
loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
2285
2291
mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
2286
- llvm::SmallVector<mlir::Type> tiv (loopArgs.size (), loopVarType);
2287
- llvm::SmallVector<mlir::Location> locs (loopArgs.size (), loc);
2288
- firOpBuilder.createBlock (&op.getRegion (), {}, tiv, locs);
2289
- // The argument is not currently in memory, so make a temporary for the
2290
- // argument, and store it there, then bind that location to the argument.
2292
+ std::fill_n (std::back_inserter (blockArgTypes), loopArgs.size (),
2293
+ loopVarType);
2294
+ std::fill_n (std::back_inserter (blockArgLocs), loopArgs.size (), loc);
2295
+ }
2296
+ if (reductionArgs.size ()) {
2297
+ llvm::copy (reductionTypes, std::back_inserter (blockArgTypes));
2298
+ std::fill_n (std::back_inserter (blockArgLocs), reductionArgs.size (), loc);
2299
+ }
2300
+
2301
+ entryBlock = firOpBuilder.createBlock (&op.getRegion (), {}, blockArgTypes,
2302
+ blockArgLocs);
2303
+
2304
+ // The argument is not currently in memory, so make a temporary for the
2305
+ // argument, and store it there, then bind that location to the argument.
2306
+ if (loopArgs.size ()) {
2291
2307
mlir::Operation *storeOp = nullptr ;
2292
2308
for (auto [argIndex, argSymbol] : llvm::enumerate (loopArgs)) {
2293
2309
mlir::Value indexVal =
@@ -2296,16 +2312,12 @@ static void createBodyOfOp(
2296
2312
createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
2297
2313
}
2298
2314
firOpBuilder.setInsertionPointAfter (storeOp);
2299
- } else if (reductionArgs.size ()) {
2300
- llvm::SmallVector<mlir::Location> locs (reductionArgs.size (), loc);
2301
- auto block =
2302
- firOpBuilder.createBlock (&op.getRegion (), {}, reductionTypes, locs);
2303
- for (auto [arg, prv] :
2304
- llvm::zip_equal (reductionArgs, block->getArguments ())) {
2305
- converter.bindSymbol (*arg, prv);
2306
- }
2307
- } else {
2308
- firOpBuilder.createBlock (&op.getRegion ());
2315
+ }
2316
+ // Bind the reduction arguments to their block arguments
2317
+ for (auto [arg, prv] : llvm::zip_equal (
2318
+ reductionArgs,
2319
+ llvm::drop_begin (entryBlock->getArguments (), loopArgs.size ()))) {
2320
+ converter.bindSymbol (*arg, prv);
2309
2321
}
2310
2322
2311
2323
// Mark the earliest insertion point.
@@ -3293,6 +3305,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
3293
3305
llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
3294
3306
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
3295
3307
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
3308
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
3296
3309
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
3297
3310
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
3298
3311
mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
@@ -3304,7 +3317,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
3304
3317
cp.processCollapse (loc, eval, lowerBound, upperBound, step, iv,
3305
3318
loopVarTypeSize);
3306
3319
cp.processScheduleChunk (stmtCtx, scheduleChunkClauseOperand);
3307
- cp.processReduction (loc, reductionVars, reductionDeclSymbols);
3320
+ cp.processReduction (loc, reductionVars, reductionDeclSymbols,
3321
+ &reductionSymbols);
3308
3322
cp.processTODO <Fortran::parser::OmpClause::Linear,
3309
3323
Fortran::parser::OmpClause::Order>(loc, ompDirective);
3310
3324
@@ -3347,9 +3361,14 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
3347
3361
3348
3362
auto *nestedEval = getCollapsedLoopEval (
3349
3363
eval, Fortran::lower::getCollapseValue (beginClauseList));
3364
+ llvm::SmallVector<mlir::Type> reductionTypes;
3365
+ reductionTypes.reserve (reductionVars.size ());
3366
+ llvm::transform (reductionVars, std::back_inserter (reductionTypes),
3367
+ [](mlir::Value v) { return v.getType (); });
3350
3368
createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, *nestedEval,
3351
3369
/* genNested=*/ true , &beginClauseList, iv,
3352
- /* outer=*/ false , &dsp);
3370
+ /* outer=*/ false , &dsp, reductionSymbols,
3371
+ reductionTypes);
3353
3372
}
3354
3373
3355
3374
static void createSimdWsLoop (
@@ -3450,12 +3469,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
3450
3469
// 2.9.3.1 SIMD construct
3451
3470
createSimdLoop (converter, eval, ompDirective, loopOpClauseList,
3452
3471
currentLocation);
3472
+ genOpenMPReduction (converter, loopOpClauseList);
3453
3473
} else {
3454
3474
createWsLoop (converter, eval, ompDirective, loopOpClauseList, endClauseList,
3455
3475
currentLocation);
3456
3476
}
3457
-
3458
- genOpenMPReduction (converter, loopOpClauseList);
3459
3477
}
3460
3478
3461
3479
static void
0 commit comments