Skip to content

[flang][OpenMP] Add reduction clause support to loop directive #128849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 52 additions & 13 deletions flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "mlir/Transforms/DialectConversion.h"

#include <memory>
#include <optional>
#include <type_traits>

namespace flangomp {
#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
Expand Down Expand Up @@ -58,7 +60,7 @@ class GenericLoopConversionPattern
if (teamsLoopCanBeParallelFor(loopOp))
rewriteToDistributeParallelDo(loopOp, rewriter);
else
rewriteToDistrbute(loopOp, rewriter);
rewriteToDistribute(loopOp, rewriter);
break;
}

Expand All @@ -77,9 +79,6 @@ class GenericLoopConversionPattern
if (loopOp.getOrder())
return todo("order");

if (!loopOp.getReductionVars().empty())
return todo("reduction");

return mlir::success();
}

Expand Down Expand Up @@ -168,7 +167,7 @@ class GenericLoopConversionPattern
case ClauseBindKind::Parallel:
return rewriteToWsloop(loopOp, rewriter);
case ClauseBindKind::Teams:
return rewriteToDistrbute(loopOp, rewriter);
return rewriteToDistribute(loopOp, rewriter);
case ClauseBindKind::Thread:
return rewriteToSimdLoop(loopOp, rewriter);
}
Expand Down Expand Up @@ -211,8 +210,9 @@ class GenericLoopConversionPattern
loopOp, rewriter);
}

void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
mlir::ConversionPatternRewriter &rewriter) const {
void rewriteToDistribute(mlir::omp::LoopOp loopOp,
mlir::ConversionPatternRewriter &rewriter) const {
assert(loopOp.getReductionVars().empty());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing the compound directive splitting in the frontend already makes sure that !$omp teams loop reduction(...) only adds the reduction clause to the teams leaf construct, so that this assert is not triggered.

We probably want to make that check in the MLIR verifier, though, since we can't rely on Flang lowering being the only source of OpenMP dialect operations. If it's an omp.loop taking the implicit role of an omp.distribute, then we need to ensure it doesn't define a reduction clause.

We'd probably want to move the checks in GenericLoopConversionPattern::rewriteStandaloneLoop, GenericLoopConversionPattern::matchAndRewrite and related utility functions from here to an extraClassDeclaration of mlir::omp::LoopOp to do this properly. Something like LoopOpRole mlir::omp::LoopOp::getLoopRole(), where the possible values are simd, wsloop, distribute and distribute parallel wsloop.

This can come as a follow-up PR, not something to be implemented here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a semantic check for this: #128823. We can look into the verifier improvement in a later PR; that makes sense.

rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
mlir::omp::DistributeOperands>(loopOp, rewriter);
}
Expand Down Expand Up @@ -246,6 +246,12 @@ class GenericLoopConversionPattern
Fortran::common::openmp::EntryBlockArgs args;
args.priv.vars = clauseOps.privateVars;

if constexpr (!std::is_same_v<OpOperandsTy,
mlir::omp::DistributeOperands>) {
populateReductionClauseOps(loopOp, clauseOps);
args.reduction.vars = clauseOps.reductionVars;
}

auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());

Expand Down Expand Up @@ -275,8 +281,7 @@ class GenericLoopConversionPattern

auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
parallelClauseOps);
mlir::Block *parallelBlock =
genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
parallelOp.setComposite(true);
rewriter.setInsertionPoint(
rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
Expand All @@ -288,20 +293,54 @@ class GenericLoopConversionPattern
rewriter.createBlock(&distributeOp.getRegion());

mlir::omp::WsloopOperands wsloopClauseOps;
populateReductionClauseOps(loopOp, wsloopClauseOps);
Fortran::common::openmp::EntryBlockArgs wsloopArgs;
wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;

auto wsloopOp =
rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
wsloopOp.setComposite(true);
rewriter.createBlock(&wsloopOp.getRegion());
genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion());

mlir::IRMapping mapper;
mlir::Block &loopBlock = *loopOp.getRegion().begin();

for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
loopBlock.getArguments(), parallelBlock->getArguments()))
auto loopBlockInterface =
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);
auto parallelBlockInterface =
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*parallelOp);
auto wsloopBlockInterface =
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*wsloopOp);

for (auto [loopOpArg, parallelOpArg] :
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Please add a comment to explain why we can match private entry block arguments to all of the entry block arguments of omp.parallel, and the same below with respect to reduction and omp.loop.

My concern is that, if at some point we introduce any other entry block arguments to the omp.parallel or omp.loop operations, this here will break and troubleshooting it won't be straightforward.

Actually, I think that casting these ops to their BlockArgOpenMPOpInterface and accessing their arguments via get{Private,Reduction}BlockArgs makes this much more resilient to future updates.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used the BlockArgOpenMPOpInterface to access the relevant args.

llvm::zip_equal(loopBlockInterface.getPrivateBlockArgs(),
parallelBlockInterface.getPrivateBlockArgs()))
mapper.map(loopOpArg, parallelOpArg);

for (auto [loopOpArg, wsloopOpArg] :
llvm::zip_equal(loopBlockInterface.getReductionBlockArgs(),
wsloopBlockInterface.getReductionBlockArgs()))
mapper.map(loopOpArg, wsloopOpArg);

rewriter.clone(*loopOp.begin(), mapper);
}

// Copies the reduction-related operands of the `omp.loop` op — modifier,
// reduction variables, reduction declaration symbols, and by-ref flags —
// into `clauseOps`, so they can be attached to the wrapper op that
// replaces the generic loop (e.g. omp.wsloop or omp.simd).
void
populateReductionClauseOps(mlir::omp::LoopOp loopOp,
mlir::omp::ReductionClauseOps &clauseOps) const {
clauseOps.reductionMod = loopOp.getReductionModAttr();
clauseOps.reductionVars = loopOp.getReductionVars();

// The symbol and by-ref lists are optional on omp.loop; copy them only
// when present so the target op's optional attributes stay unset
// otherwise.
std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms();
if (reductionSyms)
clauseOps.reductionSyms.assign(reductionSyms->begin(),
reductionSyms->end());

std::optional<llvm::ArrayRef<bool>> reductionByref =
loopOp.getReductionByref();
if (reductionByref)
clauseOps.reductionByref.assign(reductionByref->begin(),
reductionByref->end());
}
};

class GenericLoopConversionPass
Expand Down
45 changes: 44 additions & 1 deletion flang/test/Lower/OpenMP/loop-directive.f90
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ subroutine test_order()
subroutine test_reduction()
integer :: i, dummy = 1

! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
! CHECK: omp.simd private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
! CHECK-SAME: (@[[RED]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]] : !{{.*}}) {
! CHECK-NEXT: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
! CHECK: %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_reductionEdummy"}
Expand Down Expand Up @@ -294,3 +294,46 @@ subroutine teams_loop_cannot_be_parallel_for_4
!$omp end parallel
END DO
end subroutine

! CHECK-LABEL: func.func @_QPloop_parallel_bind_reduction
subroutine loop_parallel_bind_reduction
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add one of these reduction tests for the distribute parallel do case? That way we test the rewriteToDistributeParallelDo function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

implicit none
integer :: x, i

! CHECK: omp.wsloop
! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>)
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
! CHECK-NEXT: omp.loop_nest {{.*}} {
! CHECK-NEXT: hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
! CHECK-NEXT: hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
! CHECK: }
! CHECK: }
!$omp loop bind(parallel) reduction(+: x)
do i = 0, 10
x = x + i
end do
end subroutine

! Verify that `!$omp teams loop reduction(...)` is rewritten to the composite
! teams + parallel + distribute + wsloop nest, with the reduction clause
! attached to the inner omp.wsloop (not to omp.distribute).
! CHECK-LABEL: func.func @_QPloop_teams_loop_reduction
subroutine loop_teams_loop_reduction
implicit none
integer :: x, i
! CHECK: omp.teams {
! CHECK: omp.parallel
! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>) {
! CHECK: omp.distribute {
! CHECK: omp.wsloop
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
! CHECK-NEXT: omp.loop_nest {{.*}} {
! CHECK-NEXT: hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
! CHECK-NEXT: hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
! CHECK: }
! CHECK: }
! CHECK: }
! CHECK: }
! CHECK: }
!$omp teams loop reduction(+: x)
do i = 0, 10
x = x + i
end do
end subroutine
16 changes: 2 additions & 14 deletions flang/test/Transforms/generic-loop-rewriting-todo.mlir
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
// RUN: fir-opt --omp-generic-loop-conversion -verify-diagnostics %s

omp.declare_reduction @add_reduction_i32 : i32 init {
^bb0(%arg0: i32):
%c0_i32 = arith.constant 0 : i32
omp.yield(%c0_i32 : i32)
} combiner {
^bb0(%arg0: i32, %arg1: i32):
%0 = arith.addi %arg0, %arg1 : i32
omp.yield(%0 : i32)
}

func.func @_QPloop_order() {
omp.teams {
%c0 = arith.constant 0 : i32
%c10 = arith.constant 10 : i32
%c1 = arith.constant 1 : i32
%sum = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_orderEi"}

// expected-error@below {{not yet implemented: Unhandled clause reduction in omp.loop operation}}
omp.loop reduction(@add_reduction_i32 %sum -> %arg2 : !fir.ref<i32>) {
// expected-error@below {{not yet implemented: Unhandled clause order in omp.loop operation}}
omp.loop order(reproducible:concurrent) {
omp.loop_nest (%arg3) : i32 = (%c0) to (%c10) inclusive step (%c1) {
omp.yield
}
Expand Down