Skip to content

Commit e74f14d

Browse files
committed
[Flang] Maxloc elemental intrinsic lowering.
This is an extension to #74828 to handle maxloc too, to keep the minloc and maxloc symmetric.
1 parent 3bf21ba commit e74f14d

File tree

2 files changed

+181
-29
lines changed

2 files changed

+181
-29
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -812,54 +812,59 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
812812
// inlined elemental.
813813
// %e = hlfir.elemental %shape ({ ... })
814814
// %m = hlfir.minloc %array mask %e   (and likewise for hlfir.maxloc)
815-
class MinMaxlocElementalConversion
816-
: public mlir::OpRewritePattern<hlfir::MinlocOp> {
815+
template <typename Op>
816+
class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
817817
public:
818-
using mlir::OpRewritePattern<hlfir::MinlocOp>::OpRewritePattern;
818+
using mlir::OpRewritePattern<Op>::OpRewritePattern;
819819

820820
mlir::LogicalResult
821-
matchAndRewrite(hlfir::MinlocOp minloc,
822-
mlir::PatternRewriter &rewriter) const override {
823-
if (!minloc.getMask() || minloc.getDim() || minloc.getBack())
824-
return rewriter.notifyMatchFailure(minloc, "Did not find valid minloc");
821+
matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
822+
if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
823+
return rewriter.notifyMatchFailure(mloc,
824+
"Did not find valid minloc/maxloc");
825825

826-
auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
826+
constexpr bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;
827+
828+
auto elemental =
829+
mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
827830
if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
828-
return rewriter.notifyMatchFailure(minloc, "Did not find elemental");
831+
return rewriter.notifyMatchFailure(mloc, "Did not find elemental");
829832

830-
mlir::Value array = minloc.getArray();
833+
mlir::Value array = mloc.getArray();
831834

832-
unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
835+
unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
833836
mlir::Type arrayType = array.getType();
834837
if (!arrayType.isa<fir::BoxType>())
835838
return rewriter.notifyMatchFailure(
836-
minloc, "Currently requires a boxed type input");
839+
mloc, "Currently requires a boxed type input");
837840
mlir::Type elementType = hlfir::getFortranElementType(arrayType);
838841
if (!fir::isa_trivial(elementType))
839842
return rewriter.notifyMatchFailure(
840-
minloc, "Character arrays are currently not handled");
843+
mloc, "Character arrays are currently not handled");
841844

842-
mlir::Location loc = minloc.getLoc();
843-
fir::FirOpBuilder builder{rewriter, minloc.getOperation()};
845+
mlir::Location loc = mloc.getLoc();
846+
fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
844847
mlir::Value resultArr = builder.createTemporary(
845848
loc, fir::SequenceType::get(
846-
rank, hlfir::getFortranElementType(minloc.getType())));
849+
rank, hlfir::getFortranElementType(mloc.getType())));
847850

848-
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
849-
mlir::Type elementType) {
851+
auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
852+
mlir::Type elementType) {
850853
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
851854
const llvm::fltSemantics &sem = ty.getFloatSemantics();
852855
return builder.createRealConstant(
853856
loc, elementType,
854-
llvm::APFloat::getLargest(sem, /*Negative=*/false));
857+
llvm::APFloat::getLargest(sem, /*Negative=*/!isMax));
855858
}
856859
unsigned bits = elementType.getIntOrFloatBitWidth();
857-
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
858-
return builder.createIntegerConstant(loc, elementType, maxInt);
860+
int64_t limitInt =
861+
isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
862+
: llvm::APInt::getSignedMaxValue(bits).getSExtValue();
863+
return builder.createIntegerConstant(loc, elementType, limitInt);
859864
};
860865

861866
auto genBodyOp =
862-
[&rank, &resultArr, &elemental](
867+
[&rank, &resultArr, &elemental, isMax](
863868
fir::FirOpBuilder builder, mlir::Location loc,
864869
mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
865870
mlir::Value reduction,
@@ -899,10 +904,16 @@ class MinMaxlocElementalConversion
899904
mlir::Value cmp;
900905
if (elementType.isa<mlir::FloatType>()) {
901906
cmp = builder.create<mlir::arith::CmpFOp>(
902-
loc, mlir::arith::CmpFPredicate::OLT, elem, reduction);
907+
loc,
908+
isMax ? mlir::arith::CmpFPredicate::OGT
909+
: mlir::arith::CmpFPredicate::OLT,
910+
elem, reduction);
903911
} else if (elementType.isa<mlir::IntegerType>()) {
904912
cmp = builder.create<mlir::arith::CmpIOp>(
905-
loc, mlir::arith::CmpIPredicate::slt, elem, reduction);
913+
loc,
914+
isMax ? mlir::arith::CmpIPredicate::sgt
915+
: mlir::arith::CmpIPredicate::slt,
916+
elem, reduction);
906917
} else {
907918
llvm_unreachable("unsupported type");
908919
}
@@ -975,15 +986,15 @@ class MinMaxlocElementalConversion
975986
// AsExpr for the temporary resultArr.
976987
llvm::SmallVector<hlfir::DestroyOp> destroys;
977988
llvm::SmallVector<hlfir::AssignOp> assigns;
978-
for (auto user : minloc->getUsers()) {
989+
for (auto user : mloc->getUsers()) {
979990
if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
980991
destroys.push_back(destroy);
981992
else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
982993
assigns.push_back(assign);
983994
}
984995

985-
// Check if the minloc was the only user of the elemental (apart from a
986-
// destroy), and remove it if so.
996+
// Check if the minloc/maxloc was the only user of the elemental (apart from
997+
// a destroy), and remove it if so.
987998
mlir::Operation::user_range elemUsers = elemental->getUsers();
988999
hlfir::DestroyOp elemDestroy;
9891000
if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
@@ -996,7 +1007,7 @@ class MinMaxlocElementalConversion
9961007
rewriter.eraseOp(d);
9971008
for (auto a : assigns)
9981009
a.setOperand(0, resultArr);
999-
rewriter.replaceOp(minloc, asExpr);
1010+
rewriter.replaceOp(mloc, asExpr);
10001011
if (elemDestroy) {
10011012
rewriter.eraseOp(elemDestroy);
10021013
rewriter.eraseOp(elemental);
@@ -1030,7 +1041,8 @@ class OptimizedBufferizationPass
10301041
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
10311042
patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
10321043
patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
1033-
patterns.insert<MinMaxlocElementalConversion>(context);
1044+
patterns.insert<MinMaxlocElementalConversion<hlfir::MinlocOp>>(context);
1045+
patterns.insert<MinMaxlocElementalConversion<hlfir::MaxlocOp>>(context);
10341046

10351047
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
10361048
func, std::move(patterns), config))) {

flang/test/HLFIR/maxloc-elemental.fir

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// RUN: fir-opt %s -opt-bufferization | FileCheck %s
2+
3+
// Integer test input for the maxloc elemental conversion.
// Arguments: %arg0 is the boxed i32 array being searched, %arg1 a reference
// to the i32 threshold value, %arg2 the boxed i32 array that receives the
// result. The mask is produced by an hlfir.elemental and consumed directly by
// hlfir.maxloc (no dim or back operand), and both expressions are destroyed
// afterwards.
func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
4+
%c0 = arith.constant 0 : index
5+
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
6+
%1:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
7+
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
8+
%3 = fir.load %2#0 : !fir.ref<i32>
9+
%4:3 = fir.box_dims %0#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
10+
%5 = fir.shape %4#1 : (index) -> !fir.shape<1>
11+
// Mask expression: for each index, true when array(i) >= val (signed compare).
%6 = hlfir.elemental %5 unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
12+
^bb0(%arg3: index):
13+
%8 = hlfir.designate %0#0 (%arg3) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
14+
%9 = fir.load %8 : !fir.ref<i32>
15+
%10 = arith.cmpi sge, %9, %3 : i32
16+
%11 = fir.convert %10 : (i1) -> !fir.logical<4>
17+
hlfir.yield_element %11 : !fir.logical<4>
18+
}
19+
// The masked maxloc under test; its rank-1 i32 result is assigned to %arg2.
%7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
20+
hlfir.assign %7 to %1#0 : !hlfir.expr<1xi32>, !fir.box<!fir.array<?xi32>>
21+
hlfir.destroy %7 : !hlfir.expr<1xi32>
22+
hlfir.destroy %6 : !hlfir.expr<?x!fir.logical<4>>
23+
return
24+
}
25+
// CHECK-LABEL: func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
26+
// CHECK-NEXT: %c-2147483648_i32 = arith.constant -2147483648 : i32
27+
// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
28+
// CHECK-NEXT: %c0 = arith.constant 0 : index
29+
// CHECK-NEXT: %c1 = arith.constant 1 : index
30+
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
31+
// CHECK-NEXT: %[[V0:.*]] = fir.alloca i32
32+
// CHECK-NEXT: %[[RES:.*]] = fir.alloca !fir.array<1xi32>
33+
// CHECK-NEXT: %[[V1:.*]]:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
34+
// CHECK-NEXT: %[[V2:.*]]:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
35+
// CHECK-NEXT: %[[V3:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
36+
// CHECK-NEXT: %[[V4:.*]] = fir.load %[[V3]]#0 : !fir.ref<i32>
37+
// CHECK-NEXT: %[[V8:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
38+
// CHECK-NEXT: fir.store %c0_i32 to %[[V8]] : !fir.ref<i32>
39+
// CHECK-NEXT: fir.store %c0_i32 to %[[V0]] : !fir.ref<i32>
40+
// CHECK-NEXT: %[[V9:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
41+
// CHECK-NEXT: %[[V10:.*]] = arith.subi %[[V9]]#1, %c1 : index
42+
// CHECK-NEXT: %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10]] step %c1 iter_args(%arg4 = %c-2147483648_i32) -> (i32) {
43+
// CHECK-NEXT: %[[V14:.*]] = arith.addi %arg3, %c1 : index
44+
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V1]]#0 (%[[V14]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
45+
// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<i32>
46+
// CHECK-NEXT: %[[V17:.*]] = arith.cmpi sge, %[[V16]], %[[V4]] : i32
47+
// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (i32) {
48+
// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
49+
// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
50+
// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
51+
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
52+
// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
53+
// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<i32>
54+
// CHECK-NEXT: %[[V21:.*]] = arith.cmpi sgt, %[[V20]], %arg4 : i32
55+
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (i32) {
56+
// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
57+
// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
58+
// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
59+
// CHECK-NEXT: fir.result %[[V20]] : i32
60+
// CHECK-NEXT: } else {
61+
// CHECK-NEXT: fir.result %arg4 : i32
62+
// CHECK-NEXT: }
63+
// CHECK-NEXT: fir.result %[[V22]] : i32
64+
// CHECK-NEXT: } else {
65+
// CHECK-NEXT: fir.result %arg4 : i32
66+
// CHECK-NEXT: }
67+
// CHECK-NEXT: fir.result %[[V18]] : i32
68+
// CHECK-NEXT: }
69+
// CHECK-NEXT: %[[V12:.*]] = fir.load %[[V0]] : !fir.ref<i32>
70+
// CHECK-NEXT: %[[V13:.*]] = arith.cmpi eq, %[[V12]], %c1_i32 : i32
71+
// CHECK-NEXT: fir.if %[[V13]] {
72+
// CHECK-NEXT: %[[V14:.*]] = arith.cmpi eq, %[[V11]], %c-2147483648_i32 : i32
73+
// CHECK-NEXT: fir.if %[[V14]] {
74+
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
75+
// CHECK-NEXT: fir.store %c1_i32 to %[[V15]] : !fir.ref<i32>
76+
// CHECK-NEXT: }
77+
// CHECK-NEXT: }
78+
// CHECK-NEXT: %[[BD:.*]]:3 = fir.box_dims %[[V2]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
79+
// CHECK-NEXT: fir.do_loop %arg3 = %c1 to %[[BD]]#1 step %c1 unordered {
80+
// CHECK-NEXT: %[[V13:.*]] = hlfir.designate %[[RES]] (%arg3) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
81+
// CHECK-NEXT: %[[V14:.*]] = fir.load %[[V13]] : !fir.ref<i32>
82+
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V2]]#0 (%arg3) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
83+
// CHECK-NEXT: hlfir.assign %[[V14]] to %[[V15]] : i32, !fir.ref<i32>
84+
// CHECK-NEXT: }
85+
// CHECK-NEXT: return
86+
// CHECK-NEXT: }
87+
88+
89+
90+
// Floating-point test input for the maxloc elemental conversion.
// Same structure as the integer test above, but the searched array and the
// threshold are f32 and the mask uses an ordered floating-point compare.
// The result is still a rank-1 i32 expression assigned to %arg2.
func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<f32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
91+
%c0 = arith.constant 0 : index
92+
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box<!fir.array<?xf32>>) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
93+
%1:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
94+
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
95+
%3 = fir.load %2#0 : !fir.ref<f32>
96+
%4:3 = fir.box_dims %0#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
97+
%5 = fir.shape %4#1 : (index) -> !fir.shape<1>
98+
// Mask expression: for each index, true when array(i) >= val (ordered compare).
%6 = hlfir.elemental %5 unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
99+
^bb0(%arg3: index):
100+
%8 = hlfir.designate %0#0 (%arg3) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
101+
%9 = fir.load %8 : !fir.ref<f32>
102+
%10 = arith.cmpf oge, %9, %3 : f32
103+
%11 = fir.convert %10 : (i1) -> !fir.logical<4>
104+
hlfir.yield_element %11 : !fir.logical<4>
105+
}
106+
// The masked maxloc under test, carrying the contract fastmath flag.
%7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xf32>>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
107+
hlfir.assign %7 to %1#0 : !hlfir.expr<1xi32>, !fir.box<!fir.array<?xi32>>
108+
hlfir.destroy %7 : !hlfir.expr<1xi32>
109+
hlfir.destroy %6 : !hlfir.expr<?x!fir.logical<4>>
110+
return
111+
}
112+
// CHECK-LABEL: _QPtest_float
113+
// CHECK: %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10:.*]] step %c1 iter_args(%arg4 = %cst) -> (f32) {
114+
// CHECK-NEXT: %[[V14:.*]] = arith.addi %arg3, %c1 : index
115+
// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V1:.*]]#0 (%[[V14]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
116+
// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<f32>
117+
// CHECK-NEXT: %[[V17:.*]] = arith.cmpf oge, %[[V16]], %[[V4:.*]] : f32
118+
// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (f32) {
119+
// CHECK-NEXT: fir.store %c1_i32 to %[[V0:.*]] : !fir.ref<i32>
120+
// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %2#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
121+
// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
122+
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
123+
// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
124+
// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<f32>
125+
// CHECK-NEXT: %[[V21:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath<contract> : f32
126+
// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (f32) {
127+
// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %{{.}} (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
128+
// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
129+
// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
130+
// CHECK-NEXT: fir.result %[[V20]] : f32
131+
// CHECK-NEXT: } else {
132+
// CHECK-NEXT: fir.result %arg4 : f32
133+
// CHECK-NEXT: }
134+
// CHECK-NEXT: fir.result %[[V22]] : f32
135+
// CHECK-NEXT: } else {
136+
// CHECK-NEXT: fir.result %arg4 : f32
137+
// CHECK-NEXT: }
138+
// CHECK-NEXT: fir.result %[[V18]] : f32
139+
// CHECK-NEXT: }
140+

0 commit comments

Comments
 (0)