Skip to content

Commit d9dfde4

Browse files
tblah authored and AlexisPerry committed
[flang][OpenMP] support more reduction types for procedure designators (llvm#96057)
This re-uses reduction declarations from intrinsic operators to add support for reductions of allocatables, pointers, and arrays with procedure designators (e.g. min/max). I have split this into two commits to make it easier to review. The first one makes the functional change. The second cleans things up now that we can share much more code between intrinsic operators and procedure designators.
1 parent 071e823 commit d9dfde4

File tree

4 files changed

+444
-68
lines changed

4 files changed

+444
-68
lines changed

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 49 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -709,8 +709,8 @@ void ReductionProcessor::addDeclareReduction(
709709
}
710710
}
711711

712-
// initial pass to collect all reduction vars so we can figure out if this
713-
// should happen byref
712+
// Reduction variable processing common to both intrinsic operators and
713+
// procedure designators
714714
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
715715
for (const Object &object : objectList) {
716716
const semantics::Symbol *symbol = object.sym();
@@ -763,64 +763,56 @@ void ReductionProcessor::addDeclareReduction(
763763
reduceVarByRef.push_back(doReductionByRef(symVal));
764764
}
765765

766-
if (const auto &redDefinedOp =
767-
std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
768-
const auto &intrinsicOp{
769-
std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
770-
redDefinedOp->u)};
771-
ReductionIdentifier redId = getReductionType(intrinsicOp);
772-
switch (redId) {
773-
case ReductionIdentifier::ADD:
774-
case ReductionIdentifier::MULTIPLY:
775-
case ReductionIdentifier::AND:
776-
case ReductionIdentifier::EQV:
777-
case ReductionIdentifier::OR:
778-
case ReductionIdentifier::NEQV:
779-
break;
780-
default:
781-
TODO(currentLocation,
782-
"Reduction of some intrinsic operators is not supported");
783-
break;
784-
}
766+
for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
767+
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
768+
const auto &kindMap = firOpBuilder.getKindMap();
769+
std::string reductionName;
770+
ReductionIdentifier redId;
771+
mlir::Type redNameTy = redType;
772+
if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
773+
redNameTy = builder.getI1Type();
774+
775+
if (const auto &redDefinedOp =
776+
std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
777+
const auto &intrinsicOp{
778+
std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
779+
redDefinedOp->u)};
780+
redId = getReductionType(intrinsicOp);
781+
switch (redId) {
782+
case ReductionIdentifier::ADD:
783+
case ReductionIdentifier::MULTIPLY:
784+
case ReductionIdentifier::AND:
785+
case ReductionIdentifier::EQV:
786+
case ReductionIdentifier::OR:
787+
case ReductionIdentifier::NEQV:
788+
break;
789+
default:
790+
TODO(currentLocation,
791+
"Reduction of some intrinsic operators is not supported");
792+
break;
793+
}
785794

786-
for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
787-
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
788-
const auto &kindMap = firOpBuilder.getKindMap();
789-
if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
790-
decl = createDeclareReduction(firOpBuilder,
791-
getReductionName(intrinsicOp, kindMap,
792-
firOpBuilder.getI1Type(),
793-
isByRef),
794-
redId, redType, currentLocation, isByRef);
795-
else
796-
decl = createDeclareReduction(
797-
firOpBuilder,
798-
getReductionName(intrinsicOp, kindMap, redType, isByRef), redId,
799-
redType, currentLocation, isByRef);
800-
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
801-
firOpBuilder.getContext(), decl.getSymName()));
802-
}
803-
} else if (const auto *reductionIntrinsic =
804-
std::get_if<omp::clause::ProcedureDesignator>(
805-
&redOperator.u)) {
806-
if (ReductionProcessor::supportedIntrinsicProcReduction(
807-
*reductionIntrinsic)) {
808-
ReductionProcessor::ReductionIdentifier redId =
809-
ReductionProcessor::getReductionType(*reductionIntrinsic);
810-
for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
811-
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
812-
if (!redType.getEleTy().isIntOrIndexOrFloat())
813-
TODO(currentLocation,
814-
"Reduction of some types is not supported for intrinsics");
815-
decl = createDeclareReduction(
816-
firOpBuilder,
817-
getReductionName(getRealName(*reductionIntrinsic).ToString(),
818-
firOpBuilder.getKindMap(), redType, isByRef),
819-
redId, redType, currentLocation, isByRef);
820-
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
821-
firOpBuilder.getContext(), decl.getSymName()));
795+
reductionName =
796+
getReductionName(intrinsicOp, kindMap, redNameTy, isByRef);
797+
} else if (const auto *reductionIntrinsic =
798+
std::get_if<omp::clause::ProcedureDesignator>(
799+
&redOperator.u)) {
800+
if (!ReductionProcessor::supportedIntrinsicProcReduction(
801+
*reductionIntrinsic)) {
802+
TODO(currentLocation, "Unsupported intrinsic proc reduction");
822803
}
804+
redId = getReductionType(*reductionIntrinsic);
805+
reductionName =
806+
getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap,
807+
redNameTy, isByRef);
808+
} else {
809+
TODO(currentLocation, "Unexpected reduction type");
823810
}
811+
812+
decl = createDeclareReduction(firOpBuilder, reductionName, redId, redType,
813+
currentLocation, isByRef);
814+
reductionDeclSymbols.push_back(
815+
mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
824816
}
825817
}
826818

flang/test/Lower/OpenMP/Todo/reduction-array-intrinsic.f90

Lines changed: 0 additions & 11 deletions
This file was deleted.
Lines changed: 96 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,96 @@
1+
! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
3+
4+
subroutine max_array_reduction(l, r)
5+
integer :: l(:), r(:)
6+
7+
!$omp parallel reduction(max:l)
8+
l = max(l, r)
9+
!$omp end parallel
10+
end subroutine
11+
12+
! CHECK-LABEL: omp.declare_reduction @max_byref_box_Uxi32 : !fir.ref<!fir.box<!fir.array<?xi32>>> init {
13+
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>):
14+
! CHECK: %[[VAL_1:.*]] = arith.constant -2147483648 : i32
15+
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
16+
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
17+
! CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
18+
! CHECK: %[[VAL_5:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
19+
! CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]]#1 : (index) -> !fir.shape<1>
20+
! CHECK: %[[VAL_7:.*]] = fir.allocmem !fir.array<?xi32>, %[[VAL_5]]#1 {bindc_name = ".tmp", uniq_name = ""}
21+
! CHECK: %[[VAL_8:.*]] = arith.constant true
22+
! CHECK: %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_7]](%[[VAL_6]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.heap<!fir.array<?xi32>>)
23+
! CHECK: %[[VAL_10:.*]] = arith.constant 0 : index
24+
! CHECK: %[[VAL_11:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_10]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
25+
! CHECK: %[[VAL_12:.*]] = fir.shape_shift %[[VAL_11]]#0, %[[VAL_11]]#1 : (index, index) -> !fir.shapeshift<1>
26+
! CHECK: %[[VAL_13:.*]] = fir.rebox %[[VAL_9]]#0(%[[VAL_12]]) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<?xi32>>
27+
! CHECK: hlfir.assign %[[VAL_1]] to %[[VAL_13]] : i32, !fir.box<!fir.array<?xi32>>
28+
! CHECK: fir.store %[[VAL_13]] to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
29+
! CHECK: omp.yield(%[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xi32>>>)
30+
! CHECK-LABEL: } combiner {
31+
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>, %[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>):
32+
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
33+
! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
34+
! CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
35+
! CHECK: %[[VAL_5:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
36+
! CHECK: %[[VAL_6:.*]] = fir.shape_shift %[[VAL_5]]#0, %[[VAL_5]]#1 : (index, index) -> !fir.shapeshift<1>
37+
! CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
38+
! CHECK: fir.do_loop %[[VAL_8:.*]] = %[[VAL_7]] to %[[VAL_5]]#1 step %[[VAL_7]] unordered {
39+
! CHECK: %[[VAL_9:.*]] = fir.array_coor %[[VAL_2]](%[[VAL_6]]) %[[VAL_8]] : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
40+
! CHECK: %[[VAL_10:.*]] = fir.array_coor %[[VAL_3]](%[[VAL_6]]) %[[VAL_8]] : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
41+
! CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_9]] : !fir.ref<i32>
42+
! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
43+
! CHECK: %[[VAL_13:.*]] = arith.maxsi %[[VAL_11]], %[[VAL_12]] : i32
44+
! CHECK: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
45+
! CHECK: }
46+
! CHECK: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>)
47+
! CHECK-LABEL: } cleanup {
48+
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<?xi32>>>):
49+
! CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
50+
! CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
51+
! CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xi32>>) -> i64
52+
! CHECK: %[[VAL_4:.*]] = arith.constant 0 : i64
53+
! CHECK: %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
54+
! CHECK: fir.if %[[VAL_5]] {
55+
! CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xi32>>) -> !fir.heap<!fir.array<?xi32>>
56+
! CHECK: fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<?xi32>>
57+
! CHECK: }
58+
! CHECK: omp.yield
59+
! CHECK: }
60+
61+
! CHECK-LABEL: func.func @_QPmax_array_reduction(
62+
! CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "l"},
63+
! CHECK-SAME: %[[VAL_1:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "r"}) {
64+
! CHECK: %[[VAL_2:.*]] = fir.dummy_scope : !fir.dscope
65+
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_2]] {uniq_name = "_QFmax_array_reductionEl"} : (!fir.box<!fir.array<?xi32>>, !fir.dscope) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
66+
! CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_1]] dummy_scope %[[VAL_2]] {uniq_name = "_QFmax_array_reductionEr"} : (!fir.box<!fir.array<?xi32>>, !fir.dscope) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
67+
! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<?xi32>>
68+
! CHECK: fir.store %[[VAL_3]]#1 to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<?xi32>>>
69+
! CHECK: omp.parallel reduction(byref @max_byref_box_Uxi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<?xi32>>>) {
70+
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFmax_array_reductionEl"} : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> (!fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.array<?xi32>>>)
71+
! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<?xi32>>>
72+
! CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
73+
! CHECK: %[[VAL_10:.*]]:3 = fir.box_dims %[[VAL_8]], %[[VAL_9]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
74+
! CHECK: %[[VAL_11:.*]] = fir.shape %[[VAL_10]]#1 : (index) -> !fir.shape<1>
75+
! CHECK: %[[VAL_12:.*]] = hlfir.elemental %[[VAL_11]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
76+
! CHECK: ^bb0(%[[VAL_13:.*]]: index):
77+
! CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
78+
! CHECK: %[[VAL_15:.*]]:3 = fir.box_dims %[[VAL_8]], %[[VAL_14]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
79+
! CHECK: %[[VAL_16:.*]] = arith.constant 1 : index
80+
! CHECK: %[[VAL_17:.*]] = arith.subi %[[VAL_15]]#0, %[[VAL_16]] : index
81+
! CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_13]], %[[VAL_17]] : index
82+
! CHECK: %[[VAL_19:.*]] = hlfir.designate %[[VAL_8]] (%[[VAL_18]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
83+
! CHECK: %[[VAL_20:.*]] = fir.load %[[VAL_19]] : !fir.ref<i32>
84+
! CHECK: %[[VAL_21:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_13]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
85+
! CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
86+
! CHECK: %[[VAL_23:.*]] = arith.cmpi sgt, %[[VAL_20]], %[[VAL_22]] : i32
87+
! CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[VAL_20]], %[[VAL_22]] : i32
88+
! CHECK: hlfir.yield_element %[[VAL_24]] : i32
89+
! CHECK: }
90+
! CHECK: %[[VAL_25:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<?xi32>>>
91+
! CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_25]] : !hlfir.expr<?xi32>, !fir.box<!fir.array<?xi32>>
92+
! CHECK: hlfir.destroy %[[VAL_12]] : !hlfir.expr<?xi32>
93+
! CHECK: omp.terminator
94+
! CHECK: }
95+
! CHECK: return
96+
! CHECK: }

0 commit comments

Comments
 (0)