Skip to content

Commit be9f8ff

Browse files
[mlir][flang][openmp] Rework wsloop reduction operations (#80019)
This patch reworks how wsloop reduction operations work so that they better match the semantics expected by the OpenMP specification, following the earlier rework of parallel reductions. Under the new semantics, a private reduction variable is created as a block argument, and every operation on that variable inside the region should use it like any ordinary variable; the private variables are then combined into the shared variable. As a result, no special omp.reduction operations are needed inside the region. The reduction block arguments follow the loop control block arguments.

Co-authored-by: Kiran Chandramohan <[email protected]>
1 parent a69ecb2 commit be9f8ff

37 files changed

+2477
-1997
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3352,6 +3352,57 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
33523352
return args;
33533353
}
33543354

3355+
/// Create the entry block of region 0 of \p op with one block argument per
/// loop variable followed by one block argument per reduction variable, and
/// bind the corresponding symbols to those arguments.
///
/// Loop variables all share a single type, sized to the largest of the
/// symbols in \p loopArgs. Each loop block argument is stored to a privatized
/// variable through createAndSetPrivatizedLoopVar, and the builder's insertion
/// point is left immediately after the last such store. Each reduction symbol
/// is bound directly to its block argument.
///
/// \param op             operation whose first region receives the new block.
/// \param converter      lowering bridge supplying the builder and the symbol
///                       bindings.
/// \param loc            location attached to every created block argument.
/// \param loopArgs       symbols of the loop induction variables.
/// \param reductionArgs  symbols of the reduction variables.
/// \param reductionTypes types of the reduction block arguments; must contain
///                       one entry per symbol in \p reductionArgs (the
///                       zip_equal below requires matching lengths).
/// \return a copy of \p loopArgs.
static llvm::SmallVector<const Fortran::semantics::Symbol *>
genLoopAndReductionVars(
    mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
    mlir::Location &loc,
    const llvm::SmallVector<const Fortran::semantics::Symbol *> &loopArgs,
    const llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionArgs,
    llvm::SmallVector<mlir::Type> &reductionTypes) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  // Size both vectors from the expected argument count. blockArgTypes is
  // still empty at this point, so its size() cannot be used as the capacity
  // for blockArgLocs.
  const std::size_t numBlockArgs = loopArgs.size() + reductionArgs.size();
  llvm::SmallVector<mlir::Type> blockArgTypes;
  llvm::SmallVector<mlir::Location> blockArgLocs;
  blockArgTypes.reserve(numBlockArgs);
  blockArgLocs.reserve(numBlockArgs);

  if (!loopArgs.empty()) {
    // All loop block arguments use one type, wide enough for the largest
    // loop variable symbol.
    std::size_t loopVarTypeSize = 0;
    for (const Fortran::semantics::Symbol *arg : loopArgs)
      loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
    mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
    std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
                loopVarType);
    std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
  }
  if (!reductionArgs.empty()) {
    // Reduction block arguments come after the loop control arguments.
    llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
    std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
  }

  mlir::Block *entryBlock = firOpBuilder.createBlock(
      &op->getRegion(0), {}, blockArgTypes, blockArgLocs);

  // The argument is not currently in memory, so make a temporary for the
  // argument, and store it there, then bind that location to the argument.
  if (!loopArgs.empty()) {
    mlir::Operation *storeOp = nullptr;
    for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
      // Read the argument from the block created above, consistently with
      // the reduction binding below, rather than from the region's front
      // block.
      mlir::Value indexVal = fir::getBase(entryBlock->getArgument(argIndex));
      storeOp =
          createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
    }
    // loopArgs is non-empty here, so storeOp refers to the last store.
    firOpBuilder.setInsertionPointAfter(storeOp);
  }

  // Bind the reduction arguments to their block arguments, skipping the
  // leading loop control arguments.
  for (auto [arg, prv] : llvm::zip_equal(
           reductionArgs,
           llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
    converter.bindSymbol(*arg, prv);
  }

  return loopArgs;
}
3405+
33553406
static void
33563407
createSimdLoop(Fortran::lower::AbstractConverter &converter,
33573408
Fortran::semantics::SemanticsContext &semaCtx,
@@ -3429,6 +3480,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
34293480
llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
34303481
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
34313482
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
3483+
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
34323484
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
34333485
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
34343486
mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
@@ -3440,7 +3492,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
34403492
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
34413493
loopVarTypeSize);
34423494
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
3443-
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
3495+
cp.processReduction(loc, reductionVars, reductionDeclSymbols,
3496+
&reductionSymbols);
34443497
cp.processTODO<Fortran::parser::OmpClause::Linear,
34453498
Fortran::parser::OmpClause::Order>(loc, ompDirective);
34463499

@@ -3484,14 +3537,20 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
34843537
auto *nestedEval = getCollapsedLoopEval(
34853538
eval, Fortran::lower::getCollapseValue(beginClauseList));
34863539

3540+
llvm::SmallVector<mlir::Type> reductionTypes;
3541+
reductionTypes.reserve(reductionVars.size());
3542+
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
3543+
[](mlir::Value v) { return v.getType(); });
3544+
34873545
auto ivCallback = [&](mlir::Operation *op) {
3488-
return genLoopVars(op, converter, loc, iv);
3546+
return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols, reductionTypes);
34893547
};
34903548

34913549
createBodyOfOp<mlir::omp::WsLoopOp>(
34923550
wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
34933551
.setClauses(&beginClauseList)
34943552
.setDataSharingProcessor(&dsp)
3553+
.setReductions(&reductionSymbols, &reductionTypes)
34953554
.setGenRegionEntryCb(ivCallback));
34963555
}
34973556

@@ -3594,12 +3653,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
35943653
// 2.9.3.1 SIMD construct
35953654
createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
35963655
currentLocation);
3656+
genOpenMPReduction(converter, semaCtx, loopOpClauseList);
35973657
} else {
35983658
createWsLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
35993659
endClauseList, currentLocation);
36003660
}
3601-
3602-
genOpenMPReduction(converter, semaCtx, loopOpClauseList);
36033661
}
36043662

36053663
static void

flang/test/Fir/convert-to-llvm-openmp-and-fir.fir

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -701,10 +701,17 @@ func.func @_QPsb() {
701701
// CHECK-SAME: %[[ARRAY_REF:.*]]: !llvm.ptr
702702
// CHECK: %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
703703
// CHECK: omp.parallel {
704-
// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] -> %[[RED_ACCUMULATOR]] : !llvm.ptr) for
704+
// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) for
705705
// CHECK: %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
706706
// CHECK: %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
707-
// CHECK: omp.reduction %[[ARRAY_ELEM]], %[[RED_ACCUMULATOR]] : i32, !llvm.ptr
707+
// CHECK: %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
708+
// CHECK: %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
709+
// CHECK: %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
710+
// CHECK: %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
711+
// CHECK: %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
712+
// CHECK: %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
713+
// CHECK: %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
714+
// CHECK: llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
708715
// CHECK: omp.yield
709716
// CHECK: omp.terminator
710717
// CHECK: llvm.return
@@ -733,15 +740,20 @@ func.func @_QPsimple_reduction(%arg0: !fir.ref<!fir.array<100x!fir.logical<4>>>
733740
%c1_i32 = arith.constant 1 : i32
734741
%c100_i32 = arith.constant 100 : i32
735742
%c1_i32_0 = arith.constant 1 : i32
736-
omp.wsloop reduction(@eqv_reduction -> %1 : !fir.ref<!fir.logical<4>>) for (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
743+
omp.wsloop reduction(@eqv_reduction %1 -> %prv : !fir.ref<!fir.logical<4>>) for (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
737744
fir.store %arg1 to %3 : !fir.ref<i32>
738745
%4 = fir.load %3 : !fir.ref<i32>
739746
%5 = fir.convert %4 : (i32) -> i64
740747
%c1_i64 = arith.constant 1 : i64
741748
%6 = arith.subi %5, %c1_i64 : i64
742749
%7 = fir.coordinate_of %arg0, %6 : (!fir.ref<!fir.array<100x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
743750
%8 = fir.load %7 : !fir.ref<!fir.logical<4>>
744-
omp.reduction %8, %1 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
751+
%lprv = fir.load %prv : !fir.ref<!fir.logical<4>>
752+
%lprv1 = fir.convert %lprv : (!fir.logical<4>) -> i1
753+
%9 = fir.convert %8 : (!fir.logical<4>) -> i1
754+
%10 = arith.cmpi eq, %9, %lprv1 : i1
755+
%11 = fir.convert %10 : (i1) -> !fir.logical<4>
756+
fir.store %11 to %prv : !fir.ref<!fir.logical<4>>
745757
omp.yield
746758
}
747759
omp.terminator

0 commit comments

Comments
 (0)