Skip to content

Commit e0c6909

Browse files
authored
[flang][OpenMP] Add reduction clause support to loop directive (#128849)
Extends `loop` directive transformation by adding support for the `reduction` clause.
1 parent 55f2547 commit e0c6909

File tree

3 files changed

+98
-28
lines changed

3 files changed

+98
-28
lines changed

flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "mlir/Transforms/DialectConversion.h"
1616

1717
#include <memory>
18+
#include <optional>
19+
#include <type_traits>
1820

1921
namespace flangomp {
2022
#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
@@ -58,7 +60,7 @@ class GenericLoopConversionPattern
5860
if (teamsLoopCanBeParallelFor(loopOp))
5961
rewriteToDistributeParallelDo(loopOp, rewriter);
6062
else
61-
rewriteToDistrbute(loopOp, rewriter);
63+
rewriteToDistribute(loopOp, rewriter);
6264
break;
6365
}
6466

@@ -77,9 +79,6 @@ class GenericLoopConversionPattern
7779
if (loopOp.getOrder())
7880
return todo("order");
7981

80-
if (!loopOp.getReductionVars().empty())
81-
return todo("reduction");
82-
8382
return mlir::success();
8483
}
8584

@@ -168,7 +167,7 @@ class GenericLoopConversionPattern
168167
case ClauseBindKind::Parallel:
169168
return rewriteToWsloop(loopOp, rewriter);
170169
case ClauseBindKind::Teams:
171-
return rewriteToDistrbute(loopOp, rewriter);
170+
return rewriteToDistribute(loopOp, rewriter);
172171
case ClauseBindKind::Thread:
173172
return rewriteToSimdLoop(loopOp, rewriter);
174173
}
@@ -211,8 +210,9 @@ class GenericLoopConversionPattern
211210
loopOp, rewriter);
212211
}
213212

214-
void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
215-
mlir::ConversionPatternRewriter &rewriter) const {
213+
/// Rewrites `loopOp` by delegating to `rewriteToSingleWrapperOp`, instantiated
/// with `mlir::omp::DistributeOp` / `mlir::omp::DistributeOperands`.
///
/// The assertion states the precondition that `loopOp` carries no reduction
/// variables when this path is taken. NOTE(review): `assert` is compiled out
/// in release builds, so this is not a user-facing diagnostic — confirm that
/// callers never reach this path with a non-empty reduction list.
void rewriteToDistribute(mlir::omp::LoopOp loopOp,
214+
mlir::ConversionPatternRewriter &rewriter) const {
215+
assert(loopOp.getReductionVars().empty());
216216
rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
217217
mlir::omp::DistributeOperands>(loopOp, rewriter);
218218
}
@@ -246,6 +246,12 @@ class GenericLoopConversionPattern
246246
Fortran::common::openmp::EntryBlockArgs args;
247247
args.priv.vars = clauseOps.privateVars;
248248

249+
if constexpr (!std::is_same_v<OpOperandsTy,
250+
mlir::omp::DistributeOperands>) {
251+
populateReductionClauseOps(loopOp, clauseOps);
252+
args.reduction.vars = clauseOps.reductionVars;
253+
}
254+
249255
auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
250256
mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());
251257

@@ -275,8 +281,7 @@ class GenericLoopConversionPattern
275281

276282
auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
277283
parallelClauseOps);
278-
mlir::Block *parallelBlock =
279-
genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
284+
genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
280285
parallelOp.setComposite(true);
281286
rewriter.setInsertionPoint(
282287
rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
@@ -288,20 +293,54 @@ class GenericLoopConversionPattern
288293
rewriter.createBlock(&distributeOp.getRegion());
289294

290295
mlir::omp::WsloopOperands wsloopClauseOps;
296+
populateReductionClauseOps(loopOp, wsloopClauseOps);
297+
Fortran::common::openmp::EntryBlockArgs wsloopArgs;
298+
wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
299+
291300
auto wsloopOp =
292301
rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
293302
wsloopOp.setComposite(true);
294-
rewriter.createBlock(&wsloopOp.getRegion());
303+
genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion());
295304

296305
mlir::IRMapping mapper;
297-
mlir::Block &loopBlock = *loopOp.getRegion().begin();
298306

299-
for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
300-
loopBlock.getArguments(), parallelBlock->getArguments()))
307+
auto loopBlockInterface =
308+
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);
309+
auto parallelBlockInterface =
310+
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*parallelOp);
311+
auto wsloopBlockInterface =
312+
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*wsloopOp);
313+
314+
for (auto [loopOpArg, parallelOpArg] :
315+
llvm::zip_equal(loopBlockInterface.getPrivateBlockArgs(),
316+
parallelBlockInterface.getPrivateBlockArgs()))
301317
mapper.map(loopOpArg, parallelOpArg);
302318

319+
for (auto [loopOpArg, wsloopOpArg] :
320+
llvm::zip_equal(loopBlockInterface.getReductionBlockArgs(),
321+
wsloopBlockInterface.getReductionBlockArgs()))
322+
mapper.map(loopOpArg, wsloopOpArg);
323+
303324
rewriter.clone(*loopOp.begin(), mapper);
304325
}
326+
327+
/// Copies the reduction clause information of `loopOp` into `clauseOps`.
///
/// The reduction modifier attribute and the reduction variables are always
/// copied. The reduction symbols and the by-ref flags are returned by
/// `loopOp` as optionals and are copied only when a value is present;
/// otherwise the corresponding `clauseOps` fields are left as they were.
///
/// \param loopOp    the `omp.loop` operation to read the reduction data from.
/// \param clauseOps the clause operands structure to fill in.
void
328+
populateReductionClauseOps(mlir::omp::LoopOp loopOp,
329+
mlir::omp::ReductionClauseOps &clauseOps) const {
330+
clauseOps.reductionMod = loopOp.getReductionModAttr();
331+
clauseOps.reductionVars = loopOp.getReductionVars();
332+
333+
// The symbols attribute is optional on the op; only copy it when it is set.
std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms();
334+
if (reductionSyms)
335+
clauseOps.reductionSyms.assign(reductionSyms->begin(),
336+
reductionSyms->end());
337+
338+
// Likewise, the by-ref flags are optional; only copy them when present.
std::optional<llvm::ArrayRef<bool>> reductionByref =
339+
loopOp.getReductionByref();
340+
if (reductionByref)
341+
clauseOps.reductionByref.assign(reductionByref->begin(),
342+
reductionByref->end());
343+
}
305344
};
306345

307346
class GenericLoopConversionPass

flang/test/Lower/OpenMP/loop-directive.f90

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ subroutine test_order()
7575
subroutine test_reduction()
7676
integer :: i, dummy = 1
7777

78-
! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
78+
! CHECK: omp.simd private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
7979
! CHECK-SAME: (@[[RED]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]] : !{{.*}}) {
8080
! CHECK-NEXT: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
8181
! CHECK: %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_reductionEdummy"}
@@ -294,3 +294,46 @@ subroutine teams_loop_cannot_be_parallel_for_4
294294
!$omp end parallel
295295
END DO
296296
end subroutine
297+
298+
! CHECK-LABEL: func.func @_QPloop_parallel_bind_reduction
299+
subroutine loop_parallel_bind_reduction
300+
implicit none
301+
integer :: x, i
302+
303+
! CHECK: omp.wsloop
304+
! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>)
305+
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
306+
! CHECK-NEXT: omp.loop_nest {{.*}} {
307+
! CHECK-NEXT: hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
308+
! CHECK-NEXT: hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
309+
! CHECK: }
310+
! CHECK: }
311+
!$omp loop bind(parallel) reduction(+: x)
312+
do i = 0, 10
313+
x = x + i
314+
end do
315+
end subroutine
316+
317+
! CHECK-LABEL: func.func @_QPloop_teams_loop_reduction
318+
subroutine loop_teams_loop_reduction
319+
implicit none
320+
integer :: x, i
321+
! CHECK: omp.teams {
322+
! CHECK: omp.parallel
323+
! CHECK-SAME: private(@{{[^[:space:]]+}} %{{[^[:space:]]+}}#0 -> %[[PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>) {
324+
! CHECK: omp.distribute {
325+
! CHECK: omp.wsloop
326+
! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}}#0 -> %[[RED_ARG:.*]] : !fir.ref<i32>) {
327+
! CHECK-NEXT: omp.loop_nest {{.*}} {
328+
! CHECK-NEXT: hlfir.declare %[[PRIV_ARG]] {uniq_name = "_QF{{.*}}Ei"}
329+
! CHECK-NEXT: hlfir.declare %[[RED_ARG]] {uniq_name = "_QF{{.*}}Ex"}
330+
! CHECK: }
331+
! CHECK: }
332+
! CHECK: }
333+
! CHECK: }
334+
! CHECK: }
335+
!$omp teams loop reduction(+: x)
336+
do i = 0, 10
337+
x = x + i
338+
end do
339+
end subroutine

flang/test/Transforms/generic-loop-rewriting-todo.mlir

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,12 @@
11
// RUN: fir-opt --omp-generic-loop-conversion -verify-diagnostics %s
2-
3-
omp.declare_reduction @add_reduction_i32 : i32 init {
4-
^bb0(%arg0: i32):
5-
%c0_i32 = arith.constant 0 : i32
6-
omp.yield(%c0_i32 : i32)
7-
} combiner {
8-
^bb0(%arg0: i32, %arg1: i32):
9-
%0 = arith.addi %arg0, %arg1 : i32
10-
omp.yield(%0 : i32)
11-
}
12-
132
func.func @_QPloop_order() {
143
omp.teams {
154
%c0 = arith.constant 0 : i32
165
%c10 = arith.constant 10 : i32
176
%c1 = arith.constant 1 : i32
18-
%sum = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_orderEi"}
197

20-
// expected-error@below {{not yet implemented: Unhandled clause reduction in omp.loop operation}}
21-
omp.loop reduction(@add_reduction_i32 %sum -> %arg2 : !fir.ref<i32>) {
8+
// expected-error@below {{not yet implemented: Unhandled clause order in omp.loop operation}}
9+
omp.loop order(reproducible:concurrent) {
2210
omp.loop_nest (%arg3) : i32 = (%c0) to (%c10) inclusive step (%c1) {
2311
omp.yield
2412
}

0 commit comments

Comments
 (0)