Skip to content

Commit dbd6eb6

Browse files
authored
[flang][OpenMP] lower reductions of assumed shape arrays (#86982)
Patch 1: #86978 Patch 2: #86979
1 parent 2d00874 commit dbd6eb6

File tree

2 files changed

+112
-3
lines changed

2 files changed

+112
-3
lines changed

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,11 +523,16 @@ void ReductionProcessor::addDeclareReduction(
523523
if (reductionSymbols)
524524
reductionSymbols->push_back(symbol);
525525
mlir::Value symVal = converter.getSymbolAddress(*symbol);
526-
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
526+
mlir::Type eleType;
527+
auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
528+
if (refType)
529+
eleType = refType.getEleTy();
530+
else
531+
eleType = symVal.getType();
527532

528533
// all arrays must be boxed so that we have convenient access to all the
529534
// information needed to iterate over the array
530-
if (mlir::isa<fir::SequenceType>(redType.getEleTy())) {
535+
if (mlir::isa<fir::SequenceType>(eleType)) {
531536
// For Host associated symbols, use `SymbolBox` instead
532537
Fortran::lower::SymbolBox symBox =
533538
converter.lookupOneLevelUpSymbol(*symbol);
@@ -542,11 +547,25 @@ void ReductionProcessor::addDeclareReduction(
542547
builder.create<fir::StoreOp>(currentLocation, box, alloca);
543548

544549
symVal = alloca;
545-
redType = mlir::cast<fir::ReferenceType>(symVal.getType());
550+
} else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
551+
// boxed arrays are passed as values not by reference. Unfortunately,
552+
// we can't pass a box by value to omp.redution_declare, so turn it
553+
// into a reference
554+
555+
auto alloca =
556+
builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
557+
builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
558+
symVal = alloca;
546559
} else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
547560
symVal = declOp.getBase();
548561
}
549562

563+
// this isn't the same as the by-val and by-ref passing later in the
564+
// pipeline. Both styles assume that the variable is a reference at
565+
// this point
566+
assert(mlir::isa<fir::ReferenceType>(symVal.getType()) &&
567+
"reduction input var is a reference");
568+
550569
reductionVars.push_back(symVal);
551570
}
552571
const bool isByRef = doReductionByRef(reductionVars);
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
! RUN: bbc -emit-hlfir -fopenmp -o - %s | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s | FileCheck %s
3+
4+
program reduce_assumed_shape
5+
real(8), dimension(2) :: r
6+
r = 0
7+
call reduce(r)
8+
print *, r
9+
10+
contains
11+
subroutine reduce(r)
12+
implicit none
13+
real(8),intent(inout) :: r(:)
14+
integer :: i = 0
15+
16+
!$omp parallel do reduction(+:r)
17+
do i=0,10
18+
r(1) = i
19+
r(2) = 1
20+
enddo
21+
!$omp end parallel do
22+
end subroutine
23+
end program
24+
25+
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_Uxf64 : !fir.ref<!fir.box<!fir.array<?xf64>>> init {
26+
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xf64>>>):
27+
! CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f64
28+
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
29+
! CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
30+
! CHECK: %[[VAL_4:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
31+
! CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]]#1 : (index) -> !fir.shape<1>
32+
! CHECK: %[[VAL_6:.*]] = fir.alloca !fir.array<?xf64>, %[[VAL_4]]#1 {bindc_name = ".tmp"}
33+
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.ref<!fir.array<?xf64>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf64>>, !fir.ref<!fir.array<?xf64>>)
34+
! CHECK: hlfir.assign %[[VAL_1]] to %[[VAL_7]]#0 : f64, !fir.box<!fir.array<?xf64>>
35+
! CHECK: %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<?xf64>>
36+
! CHECK: fir.store %[[VAL_7]]#0 to %[[VAL_8]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
37+
! CHECK: omp.yield(%[[VAL_8]] : !fir.ref<!fir.box<!fir.array<?xf64>>>)
38+
39+
! CHECK-LABEL: } combiner {
40+
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xf64>>>, %[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.array<?xf64>>>):
41+
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
42+
! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
43+
! CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
44+
! CHECK: %[[VAL_5:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
45+
! CHECK: %[[VAL_6:.*]] = fir.shape_shift %[[VAL_5]]#0, %[[VAL_5]]#1 : (index, index) -> !fir.shapeshift<1>
46+
! CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
47+
! CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_7]] to %[[VAL_5]]#1 step %[[VAL_7]] unordered {
48+
! CHECK: %[[VAL_9:.*]] = fir.array_coor %[[VAL_2]](%[[VAL_6]]) %[[VAL_8]] : (!fir.box<!fir.array<?xf64>>, !fir.shapeshift<1>, index) -> !fir.ref<f64>
49+
! CHECK: %[[VAL_10:.*]] = fir.array_coor %[[VAL_3]](%[[VAL_6]]) %[[VAL_8]] : (!fir.box<!fir.array<?xf64>>, !fir.shapeshift<1>, index) -> !fir.ref<f64>
50+
! CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_9]] : !fir.ref<f64>
51+
! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_10]] : !fir.ref<f64>
52+
! CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] fastmath<contract> : f64
53+
! CHECK: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<f64>
54+
! CHECK: }
55+
! CHECK: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xf64>>>)
56+
! CHECK: }
57+
58+
! CHECK-LABEL: func.func private @_QFPreduce(
59+
! CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "r"}) attributes {{.*}} {
60+
! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFFreduceEi) : !fir.ref<i32>
61+
! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFFreduceEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
62+
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]] {fortran_attrs = {{.*}}, uniq_name = "_QFFreduceEr"} : (!fir.box<!fir.array<?xf64>>) -> (!fir.box<!fir.array<?xf64>>, !fir.box<!fir.array<?xf64>>)
63+
! CHECK: omp.parallel {
64+
! CHECK: %[[VAL_4:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
65+
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = "_QFFreduceEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
66+
! CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
67+
! CHECK: %[[VAL_7:.*]] = arith.constant 10 : i32
68+
! CHECK: %[[VAL_8:.*]] = arith.constant 1 : i32
69+
! CHECK: %[[VAL_9:.*]] = fir.alloca !fir.box<!fir.array<?xf64>>
70+
! CHECK: fir.store %[[VAL_3]]#1 to %[[VAL_9]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
71+
! CHECK: omp.wsloop byref reduction(@add_reduction_byref_box_Uxf64 %[[VAL_9]] -> %[[VAL_10:.*]] : !fir.ref<!fir.box<!fir.array<?xf64>>>) for (%[[VAL_11:.*]]) : i32 = (%[[VAL_6]]) to (%[[VAL_7]]) inclusive step (%[[VAL_8]]) {
72+
! CHECK: fir.store %[[VAL_11]] to %[[VAL_5]]#1 : !fir.ref<i32>
73+
! CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_10]] {fortran_attrs = {{.*}}, uniq_name = "_QFFreduceEr"} : (!fir.ref<!fir.box<!fir.array<?xf64>>>) -> (!fir.ref<!fir.box<!fir.array<?xf64>>>, !fir.ref<!fir.box<!fir.array<?xf64>>>)
74+
! CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
75+
! CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (i32) -> f64
76+
! CHECK: %[[VAL_15:.*]] = fir.load %[[VAL_12]]#0 : !fir.ref<!fir.box<!fir.array<?xf64>>>
77+
! CHECK: %[[VAL_16:.*]] = arith.constant 1 : index
78+
! CHECK: %[[VAL_17:.*]] = hlfir.designate %[[VAL_15]] (%[[VAL_16]]) : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
79+
! CHECK: hlfir.assign %[[VAL_14]] to %[[VAL_17]] : f64, !fir.ref<f64>
80+
! CHECK: %[[VAL_18:.*]] = arith.constant 1.000000e+00 : f64
81+
! CHECK: %[[VAL_19:.*]] = fir.load %[[VAL_12]]#0 : !fir.ref<!fir.box<!fir.array<?xf64>>>
82+
! CHECK: %[[VAL_20:.*]] = arith.constant 2 : index
83+
! CHECK: %[[VAL_21:.*]] = hlfir.designate %[[VAL_19]] (%[[VAL_20]]) : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
84+
! CHECK: hlfir.assign %[[VAL_18]] to %[[VAL_21]] : f64, !fir.ref<f64>
85+
! CHECK: omp.yield
86+
! CHECK: }
87+
! CHECK: omp.terminator
88+
! CHECK: }
89+
! CHECK: return
90+
! CHECK: }

0 commit comments

Comments
 (0)