Skip to content

Commit 78d3a7f

Browse files
committed
[mlir][flang][openmp] Rework wsloop reduction operations
This patch reworks the way that wsloop reduction operations function to better match the expected semantics from the OpenMP specification, following the rework of parallel reductions. The new semantics create a private reduction variable as a block argument which should be used normally for all operations on that variable in the region; this private variable is then combined with the others into the shared variable. This way no special omp.reduction operations are needed inside the region. These block arguments follow the loop control block arguments.
1 parent 10f9807 commit 78d3a7f

35 files changed

+2392
-1994
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2274,6 +2274,12 @@ static void createBodyOfOp(
22742274
return undef.getDefiningOp();
22752275
};
22762276

2277+
llvm::SmallVector<mlir::Type> blockArgTypes;
2278+
llvm::SmallVector<mlir::Location> blockArgLocs;
2279+
blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
2280+
blockArgLocs.reserve(blockArgTypes.size());
2281+
mlir::Block *entryBlock;
2282+
22772283
// If an argument for the region is provided then create the block with that
22782284
// argument. Also update the symbol's address with the mlir argument value.
22792285
// e.g. For loops the argument is the induction variable. And all further
@@ -2283,11 +2289,21 @@ static void createBodyOfOp(
22832289
for (const Fortran::semantics::Symbol *arg : loopArgs)
22842290
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
22852291
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
2286-
llvm::SmallVector<mlir::Type> tiv(loopArgs.size(), loopVarType);
2287-
llvm::SmallVector<mlir::Location> locs(loopArgs.size(), loc);
2288-
firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
2289-
// The argument is not currently in memory, so make a temporary for the
2290-
// argument, and store it there, then bind that location to the argument.
2292+
std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
2293+
loopVarType);
2294+
std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
2295+
}
2296+
if (reductionArgs.size()) {
2297+
llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
2298+
std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
2299+
}
2300+
2301+
entryBlock = firOpBuilder.createBlock(&op.getRegion(), {}, blockArgTypes,
2302+
blockArgLocs);
2303+
2304+
// The argument is not currently in memory, so make a temporary for the
2305+
// argument, and store it there, then bind that location to the argument.
2306+
if (loopArgs.size()) {
22912307
mlir::Operation *storeOp = nullptr;
22922308
for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
22932309
mlir::Value indexVal =
@@ -2296,16 +2312,12 @@ static void createBodyOfOp(
22962312
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
22972313
}
22982314
firOpBuilder.setInsertionPointAfter(storeOp);
2299-
} else if (reductionArgs.size()) {
2300-
llvm::SmallVector<mlir::Location> locs(reductionArgs.size(), loc);
2301-
auto block =
2302-
firOpBuilder.createBlock(&op.getRegion(), {}, reductionTypes, locs);
2303-
for (auto [arg, prv] :
2304-
llvm::zip_equal(reductionArgs, block->getArguments())) {
2305-
converter.bindSymbol(*arg, prv);
2306-
}
2307-
} else {
2308-
firOpBuilder.createBlock(&op.getRegion());
2315+
}
2316+
// Bind the reduction arguments to their block arguments
2317+
for (auto [arg, prv] : llvm::zip_equal(
2318+
reductionArgs,
2319+
llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
2320+
converter.bindSymbol(*arg, prv);
23092321
}
23102322

23112323
// Mark the earliest insertion point.
@@ -3293,6 +3305,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
32933305
llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
32943306
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
32953307
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
3308+
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
32963309
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
32973310
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
32983311
mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
@@ -3304,7 +3317,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
33043317
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
33053318
loopVarTypeSize);
33063319
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
3307-
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
3320+
cp.processReduction(loc, reductionVars, reductionDeclSymbols,
3321+
&reductionSymbols);
33083322
cp.processTODO<Fortran::parser::OmpClause::Linear,
33093323
Fortran::parser::OmpClause::Order>(loc, ompDirective);
33103324

@@ -3347,9 +3361,14 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
33473361

33483362
auto *nestedEval = getCollapsedLoopEval(
33493363
eval, Fortran::lower::getCollapseValue(beginClauseList));
3364+
llvm::SmallVector<mlir::Type> reductionTypes;
3365+
reductionTypes.reserve(reductionVars.size());
3366+
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
3367+
[](mlir::Value v) { return v.getType(); });
33503368
createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, *nestedEval,
33513369
/*genNested=*/true, &beginClauseList, iv,
3352-
/*outer=*/false, &dsp);
3370+
/*outer=*/false, &dsp, reductionSymbols,
3371+
reductionTypes);
33533372
}
33543373

33553374
static void createSimdWsLoop(
@@ -3450,12 +3469,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
34503469
// 2.9.3.1 SIMD construct
34513470
createSimdLoop(converter, eval, ompDirective, loopOpClauseList,
34523471
currentLocation);
3472+
genOpenMPReduction(converter, loopOpClauseList);
34533473
} else {
34543474
createWsLoop(converter, eval, ompDirective, loopOpClauseList, endClauseList,
34553475
currentLocation);
34563476
}
3457-
3458-
genOpenMPReduction(converter, loopOpClauseList);
34593477
}
34603478

34613479
static void

flang/test/Fir/convert-to-llvm-openmp-and-fir.fir

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -701,10 +701,17 @@ func.func @_QPsb() {
701701
// CHECK-SAME: %[[ARRAY_REF:.*]]: !llvm.ptr
702702
// CHECK: %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
703703
// CHECK: omp.parallel {
704-
// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] -> %[[RED_ACCUMULATOR]] : !llvm.ptr) for
704+
// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) for
705705
// CHECK: %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
706706
// CHECK: %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
707-
// CHECK: omp.reduction %[[ARRAY_ELEM]], %[[RED_ACCUMULATOR]] : i32, !llvm.ptr
707+
// CHECK: %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
708+
// CHECK: %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
709+
// CHECK: %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
710+
// CHECK: %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
711+
// CHECK: %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
712+
// CHECK: %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
713+
// CHECK: %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
714+
// CHECK: llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
708715
// CHECK: omp.yield
709716
// CHECK: omp.terminator
710717
// CHECK: llvm.return
@@ -733,15 +740,20 @@ func.func @_QPsimple_reduction(%arg0: !fir.ref<!fir.array<100x!fir.logical<4>>>
733740
%c1_i32 = arith.constant 1 : i32
734741
%c100_i32 = arith.constant 100 : i32
735742
%c1_i32_0 = arith.constant 1 : i32
736-
omp.wsloop reduction(@eqv_reduction -> %1 : !fir.ref<!fir.logical<4>>) for (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
743+
omp.wsloop reduction(@eqv_reduction %1 -> %prv : !fir.ref<!fir.logical<4>>) for (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
737744
fir.store %arg1 to %3 : !fir.ref<i32>
738745
%4 = fir.load %3 : !fir.ref<i32>
739746
%5 = fir.convert %4 : (i32) -> i64
740747
%c1_i64 = arith.constant 1 : i64
741748
%6 = arith.subi %5, %c1_i64 : i64
742749
%7 = fir.coordinate_of %arg0, %6 : (!fir.ref<!fir.array<100x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
743750
%8 = fir.load %7 : !fir.ref<!fir.logical<4>>
744-
omp.reduction %8, %1 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
751+
%lprv = fir.load %prv : !fir.ref<!fir.logical<4>>
752+
%lprv1 = fir.convert %lprv : (!fir.logical<4>) -> i1
753+
%9 = fir.convert %8 : (!fir.logical<4>) -> i1
754+
%10 = arith.cmpi eq, %9, %lprv1 : i1
755+
%11 = fir.convert %10 : (i1) -> !fir.logical<4>
756+
fir.store %11 to %prv : !fir.ref<!fir.logical<4>>
745757
omp.yield
746758
}
747759
omp.terminator

0 commit comments

Comments
 (0)